diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile index b6b802a7c6e..ebf23ba5cff 100644 --- a/.devops/vulkan.Dockerfile +++ b/.devops/vulkan.Dockerfile @@ -1,9 +1,7 @@ -ARG UBUNTU_VERSION=25.10 +ARG UBUNTU_VERSION=26.04 FROM ubuntu:$UBUNTU_VERSION AS build -# Ref: https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html - # Install build tools RUN apt update && apt install -y git build-essential cmake wget xz-utils diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index be4e23bc813..a57d0e8b1cb 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -351,16 +351,10 @@ jobs: fetch-depth: 0 ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} - - name: libCURL - id: get_libcurl - uses: ./.github/actions/windows-setup-curl - - name: Build id: cmake_build - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | - cmake -B build -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" + cmake -B build -DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server - name: Python setup @@ -374,13 +368,6 @@ jobs: run: | pip install -r tools/server/tests/requirements.txt - - name: Copy Libcurl - id: prepare_libcurl - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} - run: | - cp $env:CURL_PATH/bin/libcurl-x64.dll ./build/bin/Release/libcurl-x64.dll - - name: Tests id: server_integration_tests if: ${{ !matrix.disabled_on_pr || !github.event.pull_request }} diff --git a/CODEOWNERS b/CODEOWNERS index 908d13a35b9..6ef6c0489f1 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -2,10 +2,8 @@ # multiplie collaborators per item can be specified /.devops/*.Dockerfile @ngxson -/.github/actions/ @slaren @CISC +/.github/actions/ @CISC /.github/workflows/ @CISC -/.github/workflows/release.yml @slaren -/.github/workflows/winget.yml @slaren /ci/ @ggerganov /cmake/ @ggerganov /common/CMakeLists.txt @ggerganov @@ -40,21 +38,14 @@ /examples/passkey/ @ggerganov /examples/retrieval/ @ggerganov /examples/save-load-state/ @ggerganov -/examples/simple-chat/ @slaren -/examples/simple/ @slaren /examples/speculative-simple/ @ggerganov /examples/speculative/ @ggerganov /ggml/cmake/ @ggerganov -/ggml/include/ @ggerganov @slaren -/ggml/src/ggml-alloc.c @slaren -/ggml/src/ggml-backend* @slaren -/ggml/src/ggml-blas/ @slaren -/ggml/src/ggml-common.h @ggerganov @slaren -/ggml/src/ggml-cpu/ @ggerganov @slaren +/ggml/include/ @ggerganov +/ggml/src/ggml-common.h @ggerganov +/ggml/src/ggml-cpu/ @ggerganov /ggml/src/ggml-cpu/spacemit/ @alex-spacemit -/ggml/src/ggml-cuda/common.cuh @slaren /ggml/src/ggml-cuda/fattn* @JohannesGaessler -/ggml/src/ggml-cuda/ggml-cuda.cu @slaren /ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an /ggml/src/ggml-cuda/mmq.* @JohannesGaessler /ggml/src/ggml-cuda/mmvf.* @JohannesGaessler @@ -62,19 +53,19 @@ /ggml/src/ggml-cuda/fattn-wmma* @IMbackK /ggml/src/ggml-hip/ @IMbackK /ggml/src/ggml-cuda/vendors/hip.h @IMbackK -/ggml/src/ggml-impl.h @ggerganov @slaren +/ggml/src/ggml-impl.h @ggerganov /ggml/src/ggml-metal/ @ggerganov /ggml/src/ggml-opencl/ @lhez @max-krasnyansky /ggml/src/ggml-hexagon/ @max-krasnyansky @lhez /ggml/src/ggml-opt.cpp @JohannesGaessler /ggml/src/ggml-quants.* @ggerganov /ggml/src/ggml-rpc/ @rgerganov -/ggml/src/ggml-threading.* @ggerganov @slaren +/ggml/src/ggml-threading.* @ggerganov /ggml/src/ggml-vulkan/ @0cc4m /ggml/src/ggml-webgpu/ @reeselevine /ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM -/ggml/src/ggml.c @ggerganov @slaren -/ggml/src/ggml.cpp @ggerganov @slaren +/ggml/src/ggml.c @ggerganov +/ggml/src/ggml.cpp @ggerganov /ggml/src/gguf.cpp @JohannesGaessler @Green-Sky /gguf-py/ @CISC /media/ @ggerganov @@ -86,15 +77,11 @@ /src/llama-arch.* @CISC /src/llama-chat.* @ngxson /src/llama-graph.* @CISC -/src/llama-model-loader.* @slaren /src/llama-model.* @CISC /src/llama-vocab.* @CISC /src/models/ @CISC /tests/ @ggerganov -/tests/test-backend-ops.cpp @slaren -/tests/test-thread-safety.cpp @slaren /tools/batched-bench/ @ggerganov -/tools/llama-bench/ @slaren /tools/main/ @ggerganov /tools/mtmd/ @ngxson /tools/perplexity/ @ggerganov @@ -106,8 +93,6 @@ /tools/tokenize/ @ggerganov /tools/tts/ @ggerganov /vendor/ @ggerganov -/.clang-format @slaren -/.clang-tidy @slaren /AUTHORS @ggerganov /CMakeLists.txt @ggerganov /CONTRIBUTING.md @ggerganov diff --git a/common/arg.cpp b/common/arg.cpp index 430ab45dfe2..dd787290d25 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1232,6 +1232,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { const auto sampler_names = string_split(value, ';'); params.sampling.samplers = common_sampler_types_from_names(sampler_names, true); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS; } ).set_sparam()); add_opt(common_arg( @@ -1261,6 +1262,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.sampling.temp = std::stof(value); params.sampling.temp = std::max(params.sampling.temp, 0.0f); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP; } ).set_sparam()); add_opt(common_arg( @@ -1268,6 +1270,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k), [](common_params & params, int value) { params.sampling.top_k = value; + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K; } ).set_sparam()); add_opt(common_arg( @@ -1275,6 +1278,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p), [](common_params & params, const std::string & value) { params.sampling.top_p = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P; } ).set_sparam()); add_opt(common_arg( @@ -1282,6 +1286,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p), [](common_params & params, const std::string & value) { params.sampling.min_p = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P; } ).set_sparam()); add_opt(common_arg( @@ -1296,6 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), [](common_params & params, const std::string & value) { params.sampling.xtc_probability = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY; } ).set_sparam()); add_opt(common_arg( @@ -1303,6 +1309,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold), [](common_params & params, const std::string & value) { params.sampling.xtc_threshold = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD; } ).set_sparam()); add_opt(common_arg( @@ -1321,6 +1328,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } params.sampling.penalty_last_n = value; params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N; } ).set_sparam()); add_opt(common_arg( @@ -1328,6 +1336,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat), [](common_params & params, const std::string & value) { params.sampling.penalty_repeat = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT; } ).set_sparam()); add_opt(common_arg( @@ -1425,6 +1434,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat), [](common_params & params, int value) { params.sampling.mirostat = value; + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT; } ).set_sparam()); add_opt(common_arg( @@ -1432,6 +1442,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta), [](common_params & params, const std::string & value) { params.sampling.mirostat_eta = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA; } ).set_sparam()); add_opt(common_arg( @@ -1439,6 +1450,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau), [](common_params & params, const std::string & value) { params.sampling.mirostat_tau = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU; } ).set_sparam()); add_opt(common_arg( diff --git a/common/common.cpp b/common/common.cpp index f3cc55247e7..0d7fd9a9371 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -8,6 +8,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "sampling.h" #include #include @@ -949,6 +950,58 @@ std::vector fs_list_files(const std::string & path) { // Model utils // +static inline void common_init_sampler_from_model( + const llama_model * model, + common_params_sampling & sparams) { + + const uint64_t config = sparams.user_sampling_config; + + auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) { + if (config & user_config) return; + + char buf[64] = {0}; + if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { + char * end = nullptr; + int32_t v = strtol(buf, &end, 10); + if (end && end != buf) dst = v; + } + }; + + auto get_float = [&](const char * key, float & dst, uint64_t user_config) { + if (config & user_config) return; + + char buf[128] = {0}; + if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { + char * end = nullptr; + float v = strtof(buf, &end); + if (end && end != buf) dst = v; + } + }; + + // Sampling sequence + if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) { + char buf[512] = {0}; + if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) { + const std::vector sampler_names = string_split(std::string(buf), ';'); + if (!sampler_names.empty()) { + sparams.samplers = common_sampler_types_from_names(sampler_names, true); + } + } + } + + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP); + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT); + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA); +} + struct common_init_result common_init_from_params(common_params & params) { common_init_result iparams; auto mparams = common_model_params_to_llama(params); @@ -960,6 +1013,8 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } + common_init_sampler_from_model(model, params.sampling); + const llama_vocab * vocab = llama_model_get_vocab(model); auto cparams = common_context_params_to_llama(params); diff --git a/common/common.h b/common/common.h index de5b404dd88..2f23d0baa83 100644 --- a/common/common.h +++ b/common/common.h @@ -140,6 +140,22 @@ struct common_grammar_trigger { llama_token token = LLAMA_TOKEN_NULL; }; +enum common_params_sampling_config : uint64_t { + COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0, + COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1, + COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2, + COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3, + COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4, + COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5, + COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6, + COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7, + COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11, +}; + + // sampling parameters struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler @@ -172,6 +188,8 @@ struct common_params_sampling { bool no_perf = false; // disable performance metrics bool timing_per_token = false; + uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers + std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 8743202ad6c..daaf0bf4974 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -565,7 +565,7 @@ def prepare_tensors(self): gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF, ) ) - or not new_name.endswith(".weight") + or new_name[-7:] not in (".weight", ".lora_a", ".lora_b") ): data_qtype = gguf.GGMLQuantizationType.F32 @@ -4183,6 +4183,21 @@ def set_vocab(self): super().set_vocab() +@ModelBase.register("RND1") +class RND1Model(Qwen2MoeModel): + model_arch = gguf.MODEL_ARCH.RND1 + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + # RND1 specific parameters + # RND1 uses bidirectional attention + self.gguf_writer.add_causal_attention(False) + + if (mask_token_id := self.hparams.get("mask_token_id")) is not None: + self.gguf_writer.add_mask_token_id(mask_token_id) + + @ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration") class Qwen3VLVisionModel(MmprojModel): def __init__(self, *args, **kwargs): @@ -10046,6 +10061,25 @@ class LazyTorchTensor(gguf.LazyBase): torch.uint8: np.uint8, } + # only used when byteswapping data. Only correct size is needed + _dtype_byteswap_map: dict[torch.dtype, type] = { + torch.float64: np.float64, + torch.float32: np.float32, + torch.bfloat16: np.float16, + torch.float16: np.float16, + torch.int64: np.int64, + torch.uint64: np.uint64, + torch.int32: np.int32, + torch.uint32: np.uint32, + torch.int16: np.int16, + torch.uint16: np.uint16, + torch.int8: np.int8, + torch.uint8: np.uint8, + torch.bool: np.uint8, + torch.float8_e4m3fn: np.uint8, + torch.float8_e5m2: np.uint8, + } + # used for safetensors slices # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046 # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734 @@ -10089,8 +10123,14 @@ def from_safetensors_slice(cls, st_slice: Any) -> Tensor: @classmethod def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor: def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor: + def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray: + if sys.byteorder == 'big': + # switch data back to big endian + tensor = tensor.view(dtype).byteswap(inplace=False) + return tensor dtype = cls._dtype_str_map[tensor.dtype] - return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape) + numpy_dtype = cls._dtype_byteswap_map[dtype] + return torch.from_numpy(byteswap_tensor(tensor.mmap_bytes(), numpy_dtype)).view(dtype).reshape(tensor.shape) dtype = cls._dtype_str_map[t.dtype] shape = t.shape lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r)) @@ -10098,10 +10138,16 @@ def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor: @classmethod def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): + def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray: + if sys.byteorder == 'big': + # switch data back to big endian + tensor = tensor.view(dtype).byteswap(inplace=False) + return tensor dtype = cls._dtype_str_map[remote_tensor.dtype] + numpy_dtype = cls._dtype_byteswap_map[dtype] shape = remote_tensor.shape meta = cls.meta_with_dtype_and_shape(dtype, shape) - lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape)) + lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.from_numpy(byteswap_tensor(np.frombuffer(r.data(), dtype=numpy_dtype), numpy_dtype)).view(dtype).reshape(shape)) return cast(torch.Tensor, lazy) @classmethod diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index 57c6cd0df1d..b0adde8a8b4 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -242,7 +242,7 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f32", help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( diff --git a/examples/batched/README.md b/examples/batched/README.md index 6013aab01fd..8cde35dd644 100644 --- a/examples/batched/README.md +++ b/examples/batched/README.md @@ -3,7 +3,7 @@ The example demonstrates batched generation from a given prompt ```bash -./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4 +./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4 --kv-unified ... diff --git a/examples/diffusion/README.md b/examples/diffusion/README.md index 26de5668aa8..f71d2413193 100644 --- a/examples/diffusion/README.md +++ b/examples/diffusion/README.md @@ -6,8 +6,54 @@ More Info: - https://github.com/ggml-org/llama.cpp/pull/14644 - https://github.com/ggml-org/llama.cpp/pull/14771 +## Parameters +The diffusion CLI supports various parameters to control the generation process: -Example of using Dream architechture: `llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual` +### Core Diffusion Parameters +- `--diffusion-steps`: Number of diffusion steps (default: 256) +- `--diffusion-algorithm`: Algorithm for token selection + - `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006. + - `1`: ENTROPY_BASED - Entropy-based selection + - `2`: MARGIN_BASED - Margin-based selection + - `3`: RANDOM - Random selection + - `4`: CONFIDENCE_BASED - Confidence-based selection (default) + - More documentation here https://github.com/DreamLM/Dream +- `--diffusion-visual`: Enable live visualization during generation -Example of using LLaDA architechture: `llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual` +### Scheduling Parameters +Choose one of the following scheduling methods: +**Timestep-based scheduling:** +- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001) + +**Block-based scheduling:** +- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32) + +### Sampling Parameters +- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random) +- `--top-k`: Top-k filtering for sampling +- `--top-p`: Top-p (nucleus) filtering for sampling +- `--seed`: Random seed for reproducibility + +### Model Parameters +- `-m`: Path to the GGUF model file +- `-p`: Input prompt text +- `-ub`: Maximum sequence length (ubatch size) +- `-c`: Context size +- `-b`: Batch size + +### Examples +#### Dream architechture: +``` +llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual +``` + +#### LLaDA architechture: +``` +llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual +``` + +#### RND1 architecture: +``` +llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001 +``` diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 869796f0e3b..0211255a762 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -25,16 +25,17 @@ if(GIT_EXE) ) endif() -# Build the version string with optional dirty flag set(GGML_VERSION "${GGML_VERSION_BASE}") -if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0) - set(GGML_VERSION "${GGML_VERSION}-dirty") -endif() if(NOT GGML_BUILD_COMMIT) set(GGML_BUILD_COMMIT "unknown") endif() +# Build the commit string with optional dirty flag +if(DEFINED GGML_GIT_DIRTY AND GGML_GIT_DIRTY EQUAL 1) + set(GGML_BUILD_COMMIT "${GGML_BUILD_COMMIT}-dirty") +endif() + include(CheckIncludeFileCXX) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 605fcfcb9c2..4dbca868bc7 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -530,6 +530,7 @@ extern "C" { GGML_OP_ARANGE, GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, + GGML_OP_TOP_K, GGML_OP_LEAKY_RELU, GGML_OP_TRI, GGML_OP_FILL, @@ -2258,18 +2259,25 @@ extern "C" { struct ggml_tensor * a, enum ggml_sort_order order); - GGML_API struct ggml_tensor * ggml_arange( + // similar to ggml_top_k but implemented as `argsort` + `view` + GGML_API struct ggml_tensor * ggml_argsort_top_k( struct ggml_context * ctx, - float start, - float stop, - float step); + struct ggml_tensor * a, + int k); // top k elements per row + // note: the resulting top k indices are in no particular order GGML_API struct ggml_tensor * ggml_top_k( struct ggml_context * ctx, struct ggml_tensor * a, int k); + GGML_API struct ggml_tensor * ggml_arange( + struct ggml_context * ctx, + float start, + float stop, + float step); + #define GGML_KQ_MASK_PAD 64 // q: [n_embd_k, n_batch, n_head, ne3 ] diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 628db3fd655..a4499509ece 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -328,6 +328,14 @@ function(ggml_add_cpu_backend_variant tag_name) set(GGML_INTERNAL_${feat} OFF) endforeach() + foreach (feat ${ARGN}) + set(GGML_INTERNAL_${feat} ON) + endforeach() + elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64") + foreach (feat RVV) + set(GGML_INTERNAL_${feat} OFF) + endforeach() + foreach (feat ${ARGN}) set(GGML_INTERNAL_${feat} ON) endforeach() @@ -402,6 +410,13 @@ if (GGML_CPU_ALL_VARIANTS) else() message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}") endif() + elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64") + if (CMAKE_SYSTEM_NAME MATCHES "Linux") + ggml_add_cpu_backend_variant(riscv64_0) + ggml_add_cpu_backend_variant(riscv64_v RVV) + else() + message(FATAL_ERROR "Unsupported RISC-V target OS: ${CMAKE_SYSTEM_NAME}") + endif() else() message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}") endif() diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 606c6d1783d..08cf2b118b2 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -3236,3 +3237,64 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst GGML_ABORT("Function is not implemented."); } } + +static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // weight + ggml_tensor * src1 = dst->src[1]; // input + GGML_TENSOR_BINARY_OP_LOCALS + + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); + + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t i02 = i2 / dps2; + const int64_t i03 = i3 / dps3; + + const int64_t i12 = i2; + const int64_t i13 = i3; + acl_tensor_ptr accumulator = + ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dst->ne, dst->nb, 2); + + // The outer product needs to be accumulated in this dimension. + for (int64_t i1 = 0; i1 < ne11; i1++) { + acl_tensor_ptr acl_input = ggml_cann_create_tensor( + (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src1->ne, src1->nb, 1); + + acl_tensor_ptr acl_weight = ggml_cann_create_tensor( + (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, src0->nb, 1); + + ggml_cann_pool_alloc output_allocator(ctx.pool()); + void * output_buffer = output_allocator.alloc(ggml_nbytes(dst)); + acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dst->ne, dst->nb, 2); + + GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get()); + float alpha_value = 1.0f; + aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha); + } + } + } +} + +void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + + const enum ggml_type type = src0->type; + + switch (type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + ggml_cann_out_prod_fp(ctx, dst); + break; + default: + GGML_ABORT("Unsupport type for GGML_OP_OUT_PROD"); + break; + } +} diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index a6c2eb1226b..1ebbc769c71 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -1125,3 +1125,23 @@ void ggml_cann_op_unary_gated(std::functionsrc[0]` and `dst->src[1]`. + * + * @see GGML_CANN_CALL_ACLNN_OP for CANN operator invocation + */ +void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 5cbf5683e1d..8995a5c121f 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1886,6 +1886,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_FLASH_ATTN_EXT: ggml_cann_flash_attn_ext(ctx, dst); break; + case GGML_OP_OUT_PROD: + ggml_cann_out_prod(ctx, dst); + break; default: return false; } @@ -2303,9 +2306,9 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, // calculate rope cache for fist layer in current device. cann_ctx->rope_cache.cached = false; + bool cann_graph_update_required = false; #ifdef USE_ACL_GRAPH bool use_cann_graph = true; - bool cann_graph_update_required = false; static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); if (!prefill_use_graph) { @@ -2336,7 +2339,6 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, } #else bool use_cann_graph = false; - bool cann_graph_update_required = false; #endif // USE_ACL_GRAPH evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required); @@ -2564,6 +2566,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_OUT_PROD: + { + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } + } case GGML_OP_CONV_TRANSPOSE_1D: // TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255. return (op->src[0]->ne[0] - 1) <= 255; diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index d0cab0bcb9c..feb56173861 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -452,22 +452,35 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/spacemit/ime_kernels.h ) endif() - set(MARCH_STR "rv64gc") - if (GGML_RV_ZFH) - string(APPEND MARCH_STR "_zfh") - endif() - if (GGML_XTHEADVECTOR) - string(APPEND MARCH_STR "_xtheadvector") - elseif (GGML_RVV) - string(APPEND MARCH_STR "_v") - if (GGML_RV_ZVFH) - string(APPEND MARCH_STR "_zvfh") + if(NOT GGML_CPU_ALL_VARIANTS) + set(MARCH_STR "rv64gc") + if (GGML_RV_ZFH) + string(APPEND MARCH_STR "_zfh") endif() + if (GGML_XTHEADVECTOR) + string(APPEND MARCH_STR "_xtheadvector") + elseif (GGML_RVV) + string(APPEND MARCH_STR "_v") + if (GGML_RV_ZVFH) + string(APPEND MARCH_STR "_zvfh") + endif() + endif() + if (GGML_RV_ZICBOP) + string(APPEND MARCH_STR "_zicbop") + endif() + list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) + else() + # Begin with the lowest baseline + set(ARCH_DEFINITIONS "") + + if (GGML_INTERNAL_RVV) + message(STATUS "RVV enabled") + list(APPEND ARCH_DEFINITIONS GGML_USE_RVV) + list(APPEND ARCH_FLAGS -march=rv64gc_v -mabi=lp64d) + endif() + + ggml_add_cpu_backend_features(${GGML_CPU_NAME} riscv ${ARCH_DEFINITIONS}) endif() - if (GGML_RV_ZICBOP) - string(APPEND MARCH_STR "_zicbop") - endif() - list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) elseif (GGML_SYSTEM_ARCH STREQUAL "s390x") message(STATUS "s390x detected") list(APPEND GGML_CPU_SOURCES diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index edfd7913903..d27a9697060 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -51,10 +51,8 @@ #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K -#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index fdd0a513b83..d2adfbea873 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -24,6 +24,29 @@ #define UNUSED GGML_UNUSED +static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in, + int16x8_t * out_mins, + int8_t * out_scales) { + constexpr uint32_t kmask1 = 0x3f3f3f3f; + constexpr uint32_t kmask2 = 0x0f0f0f0f; + constexpr uint32_t kmask3 = 0x03030303; + constexpr uint8_t scales_size = 12; + + uint32_t sm[3]; + memcpy(sm, scales_in, scales_size); + + const uint32_t mins_0_3 = sm[1] & kmask1; + const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4); + const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 }; + + *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32))); + + uint32_t scales_u32[2]; + scales_u32[0] = sm[0] & kmask1; + scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4); + memcpy(out_scales, scales_u32, 8); +} + void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); @@ -474,6 +497,162 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q4_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[ncols_interleaved / 4]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int i = 0; i < ncols_interleaved / 4; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d); + float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3 + float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7 + float32x4_t sb_min_0 = vmulq_f32(q4_dmin_0, q8_d); + float32x4_t sb_min_1 = vmulq_f32(q4_dmin_1, q8_d); + + // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567 + int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + // 2 sb each iteration + int32x4_t acc_lo[col_pairs]; + int32x4_t acc_hi[col_pairs]; + + // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block + const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); + int16_t bsums_arr[8]; + vst1q_s16(bsums_arr, bsums); + for (int sb = 0; sb < QK_K / 64; sb++) { + for (int i = 0; i < col_pairs; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later + int16x8_t q4sb_scales[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q4sb[8]; + const int offset = sb * 24 + i * 12; + decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); + } + + const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K; + + // Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns + // but still need the qs to use the low and hi bits from q4 + const int8_t * q8_base = q8_ptr[b].qs + sb * 64; + int8x16_t q8_qs[8]; + for (int i = 0; i < 8; i++) { + q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8)); + } + + // Q4s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < col_pairs; cp++) { + uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp); + uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64); + uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128); + uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192); + + acc_lo[cp] = + ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]); // 0 .. 7 + acc_lo[cp] = + ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]); // 8 ..15 + acc_lo[cp] = + ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]); // 16..23 + acc_lo[cp] = + ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]); // 24..31 + + acc_hi[cp] = + ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]); // 32..39 + acc_hi[cp] = + ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]); // 40..47 + acc_hi[cp] = + ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]); // 48..55 + acc_hi[cp] = + ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]); // 56..63 + } + + // Iterates over a pair of column pairs (4 columns) to use a single 128 register + // p = 0 -> 0123 p2 -> 4567 + for (int i = 0, p = 0; p < col_pairs; i++, p += 2) { + int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]); + int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]); + float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1; + + // 0123 or 4567 + // TODO: Single superblock mul at the end of the superblock + float32x4_t sumf_0 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0); + + float32x4_t sumf_1 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1); + } + + // Multiply Acc bsum + mins + // Each pair of subblocks share the same bsums + // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)). + int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + + // cols 0-3 bias + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); + + // cols 4-7 bias + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + } // for sb + + acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0); + acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) + ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1889,3 +2068,212 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } + +void ggml_gemm_q4_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + constexpr int q8_k_blocklen = 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + // 8 accumulators: 2 row pairs × 4 col pairs + float32x4_t acc_f32[blocklen]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int i = 0; i < blocklen; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[4][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results + int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7] + int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ... + for (int i = 0; i < 8; i++) { + acc[i] = vdupq_n_s32(0); + bias_acc[i] = vdupq_n_s32(0); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int8_t q4sb_scales[2][8]; + int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later + for (int i = 0; i < 2; i++) { + const int offset = sb * 24 + i * 12; + decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]); + } + + // q8_ptr[b].qs has interleaved Q8 rows (01, 23) + const int8_t * q8_base = q8_ptr[b].qs + sb * 256; + + int8x16_t q8_qs_01[8]; + int8x16_t q8_qs_23[8]; + + // Load 32-byte per row pair, 1 subblock each time + for (int i = 0; i < 8; i++) { + const int offset = i * 32; // 16 for row 01, 16 for row 23 + q8_qs_01[i] = vld1q_s8(q8_base + offset); + q8_qs_23[i] = vld1q_s8(q8_base + offset + 16); + } + + const int8x16_t q8s[2][8] = { + { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], + q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] }, + { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], + q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] }, + }; + + // Q4s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + for (int i = 0; i < 4; i++) { + sb_acc[i] = vdupq_n_s32(0); + } + + uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39 + uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47 + uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55 + uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63 + const int8x16_t q4_nibbles[2][4] = { + { + vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), + vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), + vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), + vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), + }, + { + vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), + vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), + vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), + vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), + } + }; + + // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8 + // for each of the internal 32 qs subblock (blk) + for (int rp = 0; rp < 2; rp++) { + for (int blk = 0; blk < 2; blk++) { + const int8x16_t * q8 = &q8s[rp][4 * blk]; + const int8x16_t * q4 = q4_nibbles[blk]; + int32x4_t acc = sb_acc[2 * rp + blk]; + // mul add for each qs in the same subblock + for (int qs_offset = 0; qs_offset < 4; qs_offset++) { + acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]); + } + sb_acc[2 * rp + blk] = acc; + } + } + + // Scales[i] corresponds to column i + const int scale_offset = cp * 2; + for (int blk = 0; blk < 2; blk++) { + const int32x4_t block_scale = { + (int32_t) q4sb_scales[blk][scale_offset], + (int32_t) q4sb_scales[blk][scale_offset], + (int32_t) q4sb_scales[blk][scale_offset + 1], + (int32_t) q4sb_scales[blk][scale_offset + 1], + }; + acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale); + } + } + + // Multiply Acc bsum + mins + for (int q8_row = 0; q8_row < 4; q8_row++) { + // Each pair of subblocks share the same bsums + // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)). + int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]); + + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + } + } // for sb + + // Reorder of i8mm output with bias and output layout + for (int i = 0; i < 8; i++) { + int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i])); + acc[i] = vcombine_s32(aux.val[0], aux.val[1]); + } + int32x4_t reorder_acc[8] = { + vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])), + vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])), + vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])), + vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])), + vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])), + vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])), + vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])), + vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])), + }; + + for (int i = 0; i < q8_k_blocklen; i++) { + for (int j = 0; j < 2; j++) { + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]); + float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4))); + const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d); + + float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4))); + const float32x4_t scale = vmulq_f32(q4_d, q8_d); + + acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins); + acc_f32[2 * i + j] = + vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale); + } + } + } // for b + + // With the previous reorder, the tile is already in the correct memory layout. + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} diff --git a/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp b/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp new file mode 100644 index 00000000000..b1818988185 --- /dev/null +++ b/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp @@ -0,0 +1,35 @@ +#include "ggml-backend-impl.h" + +#if defined(__riscv) && __riscv_xlen == 64 +#include + +//https://github.com/torvalds/linux/blob/master/arch/riscv/include/uapi/asm/hwcap.h#L24 +#ifndef COMPAT_HWCAP_ISA_V +#define COMPAT_HWCAP_ISA_V (1 << ('V' - 'A')) +#endif + +struct riscv64_features { + bool has_rvv = false; + + riscv64_features() { + uint32_t hwcap = getauxval(AT_HWCAP); + + has_rvv = !!(hwcap & COMPAT_HWCAP_ISA_V); + } +}; + +static int ggml_backend_cpu_riscv64_score() { + int score = 1; + riscv64_features rf; + +#ifdef GGML_USE_RVV + if (!rf.has_rvv) { return 0; } + score += 1 << 1; +#endif + + return score; +} + +GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_riscv64_score) + +#endif // __riscv && __riscv_xlen == 64 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index c7348cc26c1..3247af8bb03 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_argsort(params, tensor); } break; + case GGML_OP_TOP_K: + { + ggml_compute_forward_top_k(params, tensor); + } break; case GGML_OP_LEAKY_RELU: { ggml_compute_forward_leaky_relu(params, tensor); @@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: @@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan( cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; + case GGML_OP_TOP_K: + { + cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks; + } break; case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne10 = node->src[1]->ne[0]; // DK diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 41e89d83c2e..d405696539e 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding( // ggml_compute_forward_argsort template -struct argsort_cmp { +struct cmp_argsort { const float * data; bool operator()(int32_t a, int32_t b) const { if constexpr (order == GGML_SORT_ORDER_ASC) { @@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32( switch (order) { case GGML_SORT_ORDER_ASC: - std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + std::sort(dst_data, dst_data + ne0, cmp_argsort{src_data}); break; case GGML_SORT_ORDER_DESC: - std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + std::sort(dst_data, dst_data + ne0, cmp_argsort{src_data}); break; default: @@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort( } } +// ggml_compute_forward_top_k + +struct cmp_top_k { + const float * data; + bool operator()(int32_t a, int32_t b) const { + return data[a] > data[b]; + } +}; + +static void ggml_compute_forward_top_k_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + const int top_k = ne0; + + int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + for (int64_t i = ith; i < nr; i += nth) { + const float * src_data = (float *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne00; j++) { + tmp[j] = j; + } + + std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data}); + + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + + std::copy(tmp, tmp + top_k, dst_data); + + // emphasize that the order is not important + if (top_k > 1) { + std::swap(dst_data[0], dst_data[1]); + } + } +} + +void ggml_compute_forward_top_k( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_top_k_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 98a0eec16d9..0fdfee79766 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 3db26cff74b..d1321191358 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1961,6 +1961,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q4_K_8x8_q8_K; } } + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 8 == 0) { + return &q4_K_8x8_q8_K; + } + } } else if (cur->type == GGML_TYPE_Q2_K) { if (ggml_cpu_has_avx512()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 25e9308d756..99ec96869a7 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) { #define AMD_MFMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) +#if defined(GGML_USE_HIP) && defined(RDNA4) +#define AMD_WMMA_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(RDNA4) + // The Volta instructions are in principle available on Turing or newer but they are effectively unusable: #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #define VOLTA_MMA_AVAILABLE @@ -283,6 +287,10 @@ static bool amd_mfma_available(const int cc) { #endif //!defined(GGML_HIP_NO_MMQ_MFMA) } +static bool amd_wmma_available(const int cc) { + return GGML_CUDA_CC_IS_RDNA4(cc); +} + static bool volta_mma_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA; } diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 8a5e08ef667..09f9a33f909 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -39,6 +39,15 @@ template return __float2bfloat16(float(x)); } else if constexpr(std::is_same_v) { return __bfloat162float(x); + } else if constexpr(std::is_same_v && std::is_same_v) { + return __float22half2_rn(x); + } else if constexpr(std::is_same_v && std::is_same_v) { + // bypass compile error on cuda 12.0.1 +#ifdef GGML_USE_HIP + return __float22bfloat162_rn(x); +#else + return {x.x, x.y}; +#endif // GGML_USE_HIP } else if constexpr(std::is_same_v) { return int32_t(x); } else { diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh index e621cb9811a..7697c292dd6 100644 --- a/ggml/src/ggml-cuda/cpy-utils.cuh +++ b/ggml/src/ggml-cuda/cpy-utils.cuh @@ -212,6 +212,6 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { } template -static __device__ void cpy_1_flt(const char * cxi, char * cdsti) { +static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) { *(dst_t *) cdsti = ggml_cuda_cast(*(const src_t *) cxi); } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index c1afde9627f..82367ad3fb0 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -12,10 +12,10 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows template -static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { +static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { @@ -40,7 +40,7 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, } template -static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne, +static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13) { @@ -166,7 +166,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, } template -static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) { +static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { @@ -180,17 +180,17 @@ static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const in } template -static void ggml_cpy_flt_contiguous_cuda( +static void ggml_cpy_scalar_contiguous_cuda( const char * cx, char * cdst, const int64_t ne, cudaStream_t stream) { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_flt_contiguous<<>> + cpy_scalar_contiguous<<>> (cx, cdst, ne); } template -static void ggml_cpy_flt_cuda( +static void ggml_cpy_scalar_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { @@ -212,11 +212,11 @@ static void ggml_cpy_flt_cuda( (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); - cpy_flt_transpose<<>> + cpy_scalar_transpose<<>> (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_flt><<>> + cpy_scalar><<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } } @@ -399,94 +399,132 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { if (can_be_transposed) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q8_0_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q8_0_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_0_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q4_0_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_1_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q4_1_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_0_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q5_0_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_iq4_nl_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_1_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q5_1_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { if (can_be_transposed) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { if (can_be_transposed) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { + if (can_be_transposed) { + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 8fe0899bb5a..0b29074f33d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4115,6 +4115,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) { return true; } + if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) { + return true; + } if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) { return true; } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index a7a28fd1ae6..caa08b360b5 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -73,7 +73,7 @@ namespace ggml_cuda_mma { static constexpr int I = I_; static constexpr int J = J_; -#if defined(GGML_USE_HIP) +#if defined(AMD_MFMA_AVAILABLE) static constexpr int ne = I * J / 64; T x[ne] = {0}; @@ -149,6 +149,34 @@ namespace ggml_cuda_mma { return -1; } } +#elif defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + static constexpr int ne = I * J / 32; + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 16) { + return 8 * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#endif #else static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -236,6 +264,32 @@ namespace ggml_cuda_mma { return -1; } } +#elif defined(AMD_WMMA_AVAILABLE) + static constexpr int ne = I * J / 32; + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return 4 * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } #else static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -285,6 +339,34 @@ namespace ggml_cuda_mma { struct tile { static constexpr int I = I_; static constexpr int J = J_; + +#if defined(AMD_WMMA_AVAILABLE) + static constexpr int ne = I * J / 32; + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return 4 * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#else static constexpr int ne = I * J / WARP_SIZE; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; @@ -320,6 +402,7 @@ namespace ggml_cuda_mma { return -1; } } +#endif // defined(AMD_WMMA_AVAILABLE) }; template @@ -353,6 +436,21 @@ namespace ggml_cuda_mma { const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); xi[0] = xs[0]; } +#elif defined(AMD_WMMA_AVAILABLE) + if constexpr (I == 16 && J == 4) { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + }else if constexpr (I == 16 && J == 8) { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + + const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2); + xi[1] = xs1[0]; + }else{ + NO_DEVICE_CODE; + } #else #pragma unroll for (int l = 0; l < t.ne; ++l) { @@ -639,12 +737,34 @@ namespace ggml_cuda_mma { : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#elif defined(AMD_WMMA_AVAILABLE) + using halfx8_t = __attribute__((ext_vector_type(8))) _Float16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t& acc_frag = reinterpret_cast(D.x[0]); + const halfx8_t& a_frag = reinterpret_cast(A.x[0]); + const halfx8_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } + static __device__ __forceinline__ void mma( + tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) { +#if defined(AMD_WMMA_AVAILABLE) + using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t& acc_frag = reinterpret_cast(D.x[0]); + const bf16x8_t& a_frag = reinterpret_cast(A.x[0]); + const bf16x8_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // AMPERE_MMA_AVAILABLE + } + static __device__ __forceinline__ void mma( tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) { #if defined(AMD_MFMA_AVAILABLE) @@ -665,6 +785,36 @@ namespace ggml_cuda_mma { acc[0], 0, 0, 0); #endif // defined(CDNA3) + +#elif defined(AMD_WMMA_AVAILABLE) + using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; + int32x2_t * a_vec = (int32x2_t *) A.x; + int32x2_t * b_vec = (int32x2_t *) B.x; + + using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; + int32x8_t * acc = (int32x8_t *) D.x; + +#if defined(RDNA4) + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + true + ); + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[1], + true, + b_vec[1], + acc[0], + true + ); +#endif // defined(RDNA4) + #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -691,6 +841,7 @@ namespace ggml_cuda_mma { acc[0], 0, 0, 0); #endif // defined(CDNA3) + #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -735,4 +886,31 @@ namespace ggml_cuda_mma { mma(D16[1], A16[1], B); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE } + +static __device__ __forceinline__ void mma( + tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) { +#if defined(AMD_WMMA_AVAILABLE) + using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; + int32x2_t * a_vec = (int32x2_t *) A.x; + int32x2_t * b_vec = (int32x2_t *) B.x; + + using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; + int32x8_t * acc = (int32x8_t *) D.x; + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + false + ); +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif + } } + diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 153dd5a97d5..5c51a22256a 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const return false; } } else { - if (src1_ncols > 16) { + if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) { return false; } } @@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const case GGML_TYPE_F32: return ampere_mma_available(cc); case GGML_TYPE_F16: - return volta_mma_available(cc) || turing_mma_available(cc); + return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc); case GGML_TYPE_BF16: - return ampere_mma_available(cc); + return ampere_mma_available(cc) || amd_wmma_available(cc); default: return false; } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 45724e0911e..c2a0a2e42fe 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -2,6 +2,7 @@ #include "mma.cuh" #include "common.cuh" +#include "convert.cuh" using namespace ggml_cuda_mma; @@ -27,20 +28,35 @@ static __global__ void mul_mat_f( const int stride_col_id, const int stride_row_id, const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added +#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) + // Special case for tf32, just dummy mma layout as wmma doesn't support it. + constexpr int tile_B_I = std::is_same_v ? 8 : 16; + constexpr int tile_C_J = std::is_same_v ? 8 : 16; + typedef tile<16, 8, T> tile_A; + typedef tile tile_B; + typedef tile<16, tile_C_J, float> tile_C; + + constexpr bool a_supported = tile_A::supported(); + constexpr bool b_supported = tile_B::supported(); + constexpr bool c_supported = tile_C::supported(); + constexpr bool supported = a_supported && b_supported && c_supported; +#else constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - - if (!I_16_supported && !I_32_supported) { - NO_DEVICE_CODE; - return; - } + constexpr bool supported = I_16_supported || I_32_supported; constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. typedef tile tile_A; typedef tile<8, 8, T> tile_B; typedef tile tile_C; +#endif // defined(AMD_WMMA_AVAILABLE) + if constexpr (!supported) { + NO_DEVICE_CODE; + return; + } constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; @@ -161,11 +177,11 @@ static __global__ void mul_mat_f( if constexpr (!has_ids) { const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); - tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast(tmp); } else { const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0; float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f); - tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast(tmp); } } } else { @@ -239,7 +255,7 @@ static __global__ void mul_mat_f( channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); NO_DEVICE_CODE; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) } //This kernel is for larger batch sizes of mul_mat_id @@ -253,20 +269,35 @@ static __global__ void mul_mat_f_ids( const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const uint3 sis1_fd, const uint3 nch_fd) { -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added +#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) + // Special case for tf32, just dummy mma layout as wmma doesn't support it. + constexpr int tile_B_I = std::is_same_v ? 8 : 16; + constexpr int tile_C_J = std::is_same_v ? 8 : 16; + typedef tile<16, 8, T> tile_A; + typedef tile tile_B; + typedef tile<16, tile_C_J, float> tile_C; + + constexpr bool a_supported = tile_A::supported(); + constexpr bool b_supported = tile_B::supported(); + constexpr bool c_supported = tile_C::supported(); + constexpr bool supported = a_supported && b_supported && c_supported; +#else constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); + constexpr bool supported = I_16_supported || I_32_supported; - if (!I_16_supported && !I_32_supported) { - NO_DEVICE_CODE; - return; - } - - constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster. + constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. typedef tile tile_A; typedef tile<8, 8, T> tile_B; typedef tile tile_C; +#endif // defined(AMD_WMMA_AVAILABLE) + if constexpr (!supported) { + NO_DEVICE_CODE; + return; + } constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; @@ -408,7 +439,7 @@ static __global__ void mul_mat_f_ids( #pragma unroll for (int j0 = 0; j0 < tile_B::I; ++j0) { const float2 tmp = vals_buf[curr_buf][j0]; - tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast(tmp); } if (itB + 1 < ntB) { @@ -492,7 +523,7 @@ static __global__ void mul_mat_f_ids( channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); NO_DEVICE_CODE; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) } template @@ -554,7 +585,8 @@ void mul_mat_f_cuda( cudaStream_t stream, const mmf_ids_data * ids_data) { typedef tile<16, 8, T> tile_A_16; typedef tile<32, 8, T> tile_A_32; - typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, T> tile_B_16; + typedef tile< 8, 8, T> tile_B_8; GGML_ASSERT(ncols_x % 2 == 0); GGML_ASSERT(stride_row % 2 == 0); @@ -581,7 +613,8 @@ void mul_mat_f_cuda( constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4; - const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; + const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4; const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index a2c8760abea..03ceba874d8 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -306,5 +306,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return false; } - return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + if (amd_wmma_available(cc)) { + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return true; + } + } + + return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 2e133b6bda8..99760d56c72 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -92,7 +92,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 : + return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? #ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -102,7 +102,7 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 128; #else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) @@ -121,7 +121,7 @@ static constexpr __device__ int get_mmq_x_max_device() { #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } static int get_mmq_y_host(const int cc) { @@ -231,7 +231,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { #define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) static int mmq_get_granularity_host(const int mmq_x, const int cc) { - if (amd_mfma_available(cc)) { + if (amd_mfma_available(cc) || amd_wmma_available(cc)) { return mmq_x >= 128 ? 32 : 16; } else if (turing_mma_available(cc) && mmq_x >= 48) { return 16; @@ -240,7 +240,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) { } } -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 128 ? 32 : 16; } @@ -265,7 +265,7 @@ static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) { #endif // (GGML_USE_HIP) static constexpr __device__ int mmq_get_nwarps_device() { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 8; #else return 256/ggml_cuda_get_physical_warp_size(); @@ -279,14 +279,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); constexpr int nrows = warp_size / threads_per_row; @@ -305,7 +305,7 @@ template static __device__ __forceinline__ void loa const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else @@ -327,11 +327,11 @@ template static __device__ __forceinline__ void loa const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -382,14 +382,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); constexpr int nrows = warp_size / threads_per_row; @@ -408,12 +408,12 @@ template static __device__ __forceinline__ void loa const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; @@ -430,11 +430,11 @@ template static __device__ __forceinline__ void loa const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -485,14 +485,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); constexpr int nrows = warp_size / threads_per_row; @@ -527,13 +527,13 @@ template static __device__ __forceinline__ void loa qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; @@ -550,11 +550,11 @@ template static __device__ __forceinline__ void loa const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -563,14 +563,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); constexpr int nrows = warp_size / threads_per_row; @@ -603,13 +603,13 @@ template static __device__ __forceinline__ void loa qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; @@ -626,11 +626,11 @@ template static __device__ __forceinline__ void loa const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -639,14 +639,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp constexpr int threads_per_row = 32; @@ -665,13 +665,13 @@ template static __device__ __forceinline__ void loa const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; @@ -688,11 +688,11 @@ template static __device__ __forceinline__ void loa const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -701,14 +701,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4); constexpr int nrows = warp_size / threads_per_row; @@ -730,13 +730,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); const int k0 = kbx * (2 * QI_MXFP4) + kqsx; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4; @@ -753,11 +753,11 @@ template static __device__ __forceinline__ void loa const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; #else x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -796,7 +796,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) typedef tile<16, 8, int> tile_A; typedef tile<16, 8, int> tile_B; typedef tile<16, 16, int> tile_C; @@ -927,7 +927,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( } } } -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } template @@ -965,7 +965,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) typedef tile<16, 8, int> tile_A; typedef tile<16, 8, int> tile_B; typedef tile<16, 16, int> tile_C; @@ -1087,7 +1087,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( } } } -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } // Used for Q3_K, IQ2_S, and IQ2_XS @@ -1170,6 +1170,54 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( tile_C C; mma(C, A[n], B[0]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + typedef tile<16, 4, int> tile_A; + typedef tile<16, 4, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(l); @@ -1257,21 +1305,21 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #else GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q2_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; @@ -1295,11 +1343,11 @@ template static __device__ __forceinline__ void loa const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int sc_m = bxi->scales[kqsx]; @@ -1310,11 +1358,11 @@ template static __device__ __forceinline__ void loa const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -1438,6 +1486,72 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( tile_C Cd; mma(Cd, A[n], B[0]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); + float tmp = Cd.x[l]*dm.x; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm.x[l]*dm.y; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + + typedef tile<16, 4, int> tile_A; + typedef tile<16, 4, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y; + const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 + : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y + : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); + + tile_C Cm; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm, A1, B); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd; + mma(Cd, A[n], B); + #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(l); @@ -1574,7 +1688,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #else GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q3_K( @@ -1582,7 +1696,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else @@ -1618,11 +1732,11 @@ template static __device__ __forceinline__ void loa const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -1649,7 +1763,7 @@ template static __device__ __forceinline__ void loa const int sc = __vsubss4(sc_low | sc_high, 0x20202020); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) const int8_t * sc8 = (const int8_t *) ≻ const float d = bxi->d; @@ -1659,10 +1773,10 @@ template static __device__ __forceinline__ void loa } #else x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)) #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; @@ -1675,7 +1789,7 @@ template static __device__ __forceinline__ void loa x_df[i] = bxi->d; } -#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) || defined(AMD_WMMA_AVAILABLE) } template @@ -1728,7 +1842,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else @@ -1736,7 +1850,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); constexpr int nrows = warp_size / threads_per_row; @@ -1753,19 +1867,19 @@ template static __device__ __forceinline__ void loa const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; const int qs0 = get_int_b4(bxi->qs, txi); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int rows_per_warp = warp_size / 2; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) // Need if on AMD instead of % because warp_size == 64 // This causes double work and throughput loss (MI300X) // H100 loses about 100 t/s with 'if' condition over '%' @@ -1774,7 +1888,7 @@ template static __device__ __forceinline__ void loa #else int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; { -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) if (need_check) { i = min(i, i_max); } @@ -1829,7 +1943,7 @@ template static __device__ __forceinline__ void loa x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } template @@ -1872,7 +1986,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else @@ -1908,16 +2022,16 @@ template static __device__ __forceinline__ void loa const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int rows_per_warp = warp_size / 2; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { @@ -1930,7 +2044,7 @@ template static __device__ __forceinline__ void loa #else int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; { -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) if (need_check) { i = min(i, i_max); } @@ -1986,7 +2100,7 @@ template static __device__ __forceinline__ void loa x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } template @@ -2029,7 +2143,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); @@ -2038,7 +2152,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); constexpr int nrows = warp_size / threads_per_row; @@ -2065,13 +2179,13 @@ template static __device__ __forceinline__ void loa const int kq0 = 2*txi - txi % (QI6_K/2) + 0; const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } #pragma unroll @@ -2084,11 +2198,11 @@ template static __device__ __forceinline__ void loa const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 4; @@ -2102,11 +2216,11 @@ template static __device__ __forceinline__ void loa const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); #else x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2190,6 +2304,56 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( tile_C C; mma(C, A[n], B[0]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + typedef tile<16, 4, int> tile_A; + typedef tile<16, 4, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(l); @@ -2303,7 +2467,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #else GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_iq4_nl( @@ -2311,14 +2475,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); constexpr int nrows = warp_size / threads_per_row; @@ -2340,13 +2504,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int k0 = kbx * (2 * QI4_NL) + kqsx; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; @@ -2363,11 +2527,11 @@ template static __device__ __forceinline__ void loa const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2376,14 +2540,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2414,22 +2578,22 @@ template static __device__ __forceinline__ void loa const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2438,14 +2602,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2472,24 +2636,24 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2498,15 +2662,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; constexpr int nrows = warp_size / threads_per_row; const int kqsx = threadIdx.x % threads_per_row; @@ -2539,24 +2702,24 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2565,14 +2728,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2601,22 +2764,22 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2625,14 +2788,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2668,22 +2831,22 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2692,14 +2855,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); constexpr int nrows = warp_size / threads_per_row; @@ -2727,23 +2890,23 @@ template static __device__ __forceinline__ void loa const int grid0 = (grid >> 0) & 0x0F0F0F0F; const int grid1 = (grid >> 4) & 0x0F0F0F0F; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); #else x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2752,14 +2915,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); constexpr int nrows = warp_size / threads_per_row; @@ -2779,13 +2942,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int k0 = 8 * (kqsx / 4) + kqsx % 4; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 8; @@ -2804,11 +2967,11 @@ template static __device__ __forceinline__ void loa const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2848,7 +3011,7 @@ static __device__ __forceinline__ void mmq_write_back_mma( constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int nwarps = mmq_get_nwarps_device(); -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int tileC_IJ = mmq_get_granularity_device(0); typedef tile tile_C; constexpr int rows_per_warp = granularity; @@ -2859,11 +3022,11 @@ static __device__ __forceinline__ void mmq_write_back_mma( constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); -#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); #else GGML_UNUSED(nwarps); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { @@ -3063,13 +3226,13 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( int * tile_y = data_mul_mat_q + mmq_x; int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; constexpr mmq_write_back_t write_back = mmq_write_back_mma; #else constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int blocks_per_iter = MMQ_ITER_K / qk; diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index 166825c2c5f..ac422027b91 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -43,6 +43,14 @@ set(HTP_CMAKE_ARGS -DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT} -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}) +ExternalProject_Add(htp-v68 + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON + CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v68 -DPREBUILT_LIB_DIR="toolv19_v68") + +ExternalProject_Add(htp-v69 + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON + CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v69 -DPREBUILT_LIB_DIR="toolv19_v69") + ExternalProject_Add(htp-v73 SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73") @@ -61,6 +69,8 @@ ExternalProject_Add(htp-v81 # Install Hexagon skels required at runtime install(FILES + ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v68.so + ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v69.so ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index cabd301ad35..72a82a89116 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #ifdef _WIN32 # include @@ -240,6 +241,23 @@ struct ggml_hexagon_session { uint32_t prof_pkts; }; +static inline void hex_print_op_info(const ggml_tensor * op, ggml_hexagon_session * sess, const uint32_t req_flags) { + char dims[64 * GGML_MAX_SRC]; + char strides[64 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + hex_format_op_dims(dims, op); + hex_format_op_strides(strides, op); + hex_format_op_types(types, op); + hex_format_op_buffs(buffs, op); + hex_format_op_names(names, op); + + HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), + names, dims, types, strides, buffs, req_flags); +} + void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) { // Bump pending flag (cleared in the session::flush once we get the responce) this->op_pending++; // atomic inc @@ -1912,6 +1930,15 @@ static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_t return true; } +template +static inline bool hex_supported_buffer(const struct ggml_hexagon_session * sess, _TTensor... tensors) { + return ([&]() -> bool { + return !tensors || !tensors->buffer || + (ggml_backend_buffer_is_hexagon(tensors->buffer) && + ggml_backend_hexagon_buffer_get_sess(tensors->buffer) == sess); + }() && ...); +} + static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; @@ -1959,16 +1986,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s } // src0 & src1 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, dst)) { return false; } @@ -2016,20 +2034,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session // src0 (weights) must be repacked and mapped to the same session // src1 & sr2 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (src2->buffer && - (!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, src2, dst)) { return false; } @@ -2063,16 +2068,7 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se } // src0, src1 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, dst)) { return false; } @@ -2104,20 +2100,7 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se } // src0, src1 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (src2->buffer && - (!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, src2, dst)) { return false; } @@ -2144,12 +2127,7 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses } // src0 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, dst)) { return false; } @@ -2186,16 +2164,7 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session } // src0, src1 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1 && src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, dst)) { return false; } @@ -2248,16 +2217,7 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s } // src0, src1 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1 && src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, dst)) { return false; } @@ -2269,7 +2229,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess int mode = op_params[2]; - if ((mode & GGML_ROPE_TYPE_NEOX) || (mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) { + if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) { return false; } if (mode & 1) { @@ -2312,20 +2272,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess } // src0, src1, src2 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (src2 && src2->buffer && - (!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, src2, dst)) { return false; } @@ -2346,6 +2293,26 @@ static void init_htp_tensor(htp_tensor * h, const ggml_tensor * t) { h->nb[3] = t->nb[3]; } +static size_t dspqueue_buffers_init(dspqueue_buffer * buf, const ggml_tensor * t, bool flush_host, bool flush_htp) { + if (!t) { + return 0; + } + + memset(buf, 0, sizeof(*buf)); + auto tensor_buf = static_cast(t->buffer->context); + buf->fd = tensor_buf->fd; + buf->ptr = t->data; + buf->offset = (uint8_t *) t->data - tensor_buf->base; + buf->size = ggml_nbytes(t); + buf->flags = (flush_host ? DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER : 0); // Flush CPU + buf->flags |= (flush_htp ? DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT : 0); // Invalidate DSP + return 1; +} + +static ggml_hexagon_session * get_session_from_tensor(const ggml_tensor * t) { + return static_cast(t->buffer->context)->sess; +} + static void hex_dump_dspbuf(const struct ggml_tensor * t, const dspqueue_buffer * d) { auto buf = static_cast(t->buffer->context); auto sess = buf->sess; @@ -2360,10 +2327,6 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags) const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - auto src0_buf = static_cast(src0->buffer->context); - auto src1_buf = static_cast(src1->buffer->context); - auto dst_buf = static_cast(dst->buffer->context); - uint64_t t1, t2; t1 = ggml_time_us(); @@ -2385,55 +2348,27 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags) } dspqueue_buffer bufs[3]; - memset(bufs, 0, sizeof(bufs)); // First buffer Weights. // The content is static, there is no need to do any cache management - bufs[0].fd = src0_buf->fd; - bufs[0].ptr = src0->data; - bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; - bufs[0].size = ggml_nbytes(src0); - bufs[0].flags = 0; + dspqueue_buffers_init(bufs, src0, false, false); // Second buffer Input Activations. This is a buffer that the CPU // writes and the DSP reads, so we'll need to flush CPU caches and // invalidate DSP ones. On platforms with I/O coherency support the // framework will automatically skip cache operations where possible. - bufs[1].fd = src1_buf->fd; - bufs[1].ptr = src1->data; - bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; - bufs[1].size = ggml_nbytes(src1); - bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + dspqueue_buffers_init(&bufs[1], src1, true, true); // Third buffer Output Activations. We'll handle DSP // cache maintenance in the response message but need to flush // CPU caches to ensure any previously written dirty lines are // written out before writes from the DSP start. - bufs[2].fd = dst_buf->fd; - bufs[2].ptr = dst->data; - bufs[2].offset = (uint8_t *) dst->data - dst_buf->base; - bufs[2].size = ggml_nbytes(dst); - bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + dspqueue_buffers_init(&bufs[2], dst, true, false); - // Primary DSP session from the src0 (normally weight) tensor - auto sess = src0_buf->sess; + auto * sess = get_session_from_tensor(src0); if (opt_verbose) { - char dims[64 * GGML_MAX_SRC]; - char strides[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_strides(strides, op); - hex_format_op_types(types, op); - hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), - names, dims, types, strides, buffs, req.flags); + hex_print_op_info(op, sess, req.flags); if (opt_verbose > 1) { hex_dump_dspbuf(src0, &bufs[0]); hex_dump_dspbuf(src1, &bufs[1]); @@ -2463,11 +2398,6 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag const struct ggml_tensor * src2 = op->src[2]; const struct ggml_tensor * dst = op; - auto src0_buf = static_cast(src0->buffer->context); - auto src1_buf = static_cast(src1->buffer->context); - auto src2_buf = static_cast(src2->buffer->context); - auto dst_buf = static_cast(dst->buffer->context); - uint64_t t1, t2; t1 = ggml_time_us(); @@ -2490,66 +2420,32 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag } dspqueue_buffer bufs[4]; - memset(bufs, 0, sizeof(bufs)); - // First buffer Weights. // The content is static, there is no need to do any cache management - bufs[0].fd = src0_buf->fd; - bufs[0].ptr = src0->data; - bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; - bufs[0].size = ggml_nbytes(src0); - bufs[0].flags = 0; + dspqueue_buffers_init(bufs, src0, false, false); // Second buffer Input Activations. This is a buffer that the CPU // writes and the DSP reads, so we'll need to flush CPU caches and // invalidate DSP ones. On platforms with I/O coherency support the // framework will automatically skip cache operations where possible. - bufs[1].fd = src1_buf->fd; - bufs[1].ptr = src1->data; - bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; - bufs[1].size = ggml_nbytes(src1); - bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + dspqueue_buffers_init(&bufs[1], src1, true, true); // Third buffer expert IDs. This is a buffer that the CPU // writes and the DSP reads, so we'll need to flush CPU caches and // invalidate DSP ones. On platforms with I/O coherency support the // framework will automatically skip cache operations where possible. - bufs[2].fd = src2_buf->fd; - bufs[2].ptr = src2->data; - bufs[2].offset = (uint8_t *) src2->data - src2_buf->base; - bufs[2].size = ggml_nbytes(src2); - bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + dspqueue_buffers_init(&bufs[2], src2, true, true); // Forth buffer Output Activations. We'll handle DSP // cache maintenance in the response message but need to flush // CPU caches to ensure any previously written dirty lines are // written out before writes from the DSP start. - bufs[3].fd = dst_buf->fd; - bufs[3].ptr = dst->data; - bufs[3].offset = (uint8_t *) dst->data - dst_buf->base; - bufs[3].size = ggml_nbytes(dst); - bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + dspqueue_buffers_init(&bufs[3], dst, true, false); - // Primary DSP session from the src0 (normally weight) tensor - auto sess = src0_buf->sess; + auto * sess = get_session_from_tensor(src0); if (opt_verbose) { - char dims[64 * GGML_MAX_SRC]; - char strides[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_types(types, op); - hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), - names, dims, types, strides, buffs, req.flags); - + hex_print_op_info(op, sess, req.flags); if (opt_verbose > 1) { hex_dump_dspbuf(src0, &bufs[0]); hex_dump_dspbuf(src1, &bufs[1]); @@ -2581,10 +2477,6 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) { const struct ggml_tensor * src1 = node->src[1]; const struct ggml_tensor * dst = node; - auto src0_buf = static_cast(src0->buffer->context); - auto src1_buf = static_cast(src1->buffer->context); - auto dst_buf = static_cast(dst->buffer->context); - uint64_t t1 = 0; uint64_t t2 = 0; @@ -2621,60 +2513,30 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) { init_htp_tensor(&req.dst, dst); dspqueue_buffer bufs[3]; - memset(bufs, 0, sizeof(bufs)); - // First buffer = First Operand of Binary op // This is a buffer that the CPU writes and the DSP reads, so we'll // need to flush CPU caches and invalidate DSP ones. On platforms // with I/O coherency support the framework will automatically skip // cache operations where possible. - bufs[0].fd = src0_buf->fd; - bufs[0].ptr = src0->data; - bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; - bufs[0].size = ggml_nbytes(src0); - bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; + dspqueue_buffers_init(bufs, src0, true, true); // Second buffer = Second Operand of Binary op // This is a buffer that the CPU writes and the DSP reads, so we'll // need to flush CPU caches and invalidate DSP ones. On platforms // with I/O coherency support the framework will automatically skip // cache operations where possible. - bufs[1].fd = src1_buf->fd; - bufs[1].ptr = src1->data; - bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; - bufs[1].size = ggml_nbytes(src1); - bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + dspqueue_buffers_init(&bufs[1], src1, true, true); // Third buffer = Output Activations. We'll handle DSP // cache maintenance in the response message but need to flush // CPU caches to ensure any previously written dirty lines are // written out before writes from the DSP start. - bufs[2].fd = dst_buf->fd; - bufs[2].ptr = dst->data; - bufs[2].offset = (uint8_t *) dst->data - dst_buf->base; - bufs[2].size = ggml_nbytes(dst); - bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + dspqueue_buffers_init(&bufs[2], dst, true, false); - // Primary DSP session from the src0 tensor - ggml_hexagon_session * sess = src0_buf->sess; + auto * sess = get_session_from_tensor(src0); if (opt_verbose) { - char dims[64 * GGML_MAX_SRC]; - char strides[16 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_strides(strides, op); - hex_format_op_types(types, op); - hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), - ggml_op_name(node->op), names, dims, types, strides, buffs, req.flags); + hex_print_op_info(op, sess, req.flags); if (opt_verbose > 1) { hex_dump_dspbuf(src0, &bufs[0]); hex_dump_dspbuf(src1, &bufs[1]); @@ -2705,11 +2567,6 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) { const struct ggml_tensor * src2 = node->src[2]; const struct ggml_tensor * dst = node; - auto src0_buf = static_cast(src0->buffer->context); - auto src1_buf = static_cast(src1->buffer->context); - auto src2_buf = static_cast(src2->buffer->context); - auto dst_buf = static_cast(dst->buffer->context); - uint64_t t1 = 0; uint64_t t2 = 0; @@ -2741,58 +2598,19 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) { init_htp_tensor(&req.dst, dst); dspqueue_buffer bufs[4]; - memset(bufs, 0, sizeof(bufs)); - // First buffer = input activations - bufs[0].fd = src0_buf->fd; - bufs[0].ptr = src0->data; - bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; - bufs[0].size = ggml_nbytes(src0); - bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; - + dspqueue_buffers_init(bufs, src0, true, true); // Second buffer = experts bias - bufs[1].fd = src1_buf->fd; - bufs[1].ptr = src1->data; - bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; - bufs[1].size = ggml_nbytes(src1); - bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP - + dspqueue_buffers_init(&bufs[1], src1, true, true); // Third buffer = activated experts - bufs[2].fd = src2_buf->fd; - bufs[2].ptr = src2->data; - bufs[2].offset = (uint8_t *) src2->data - src2_buf->base; - bufs[2].size = ggml_nbytes(src2); - bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP - + dspqueue_buffers_init(&bufs[2], src2, true, true); // Forth buffer = output activations - bufs[3].fd = dst_buf->fd; - bufs[3].ptr = dst->data; - bufs[3].offset = (uint8_t *) dst->data - dst_buf->base; - bufs[3].size = ggml_nbytes(dst); - bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + dspqueue_buffers_init(&bufs[3], dst, true, true); - // Primary DSP session from the src0 tensor - ggml_hexagon_session * sess = src0_buf->sess; + auto * sess = get_session_from_tensor(src0); if (opt_verbose) { - char dims[64 * GGML_MAX_SRC]; - char strides[16 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_strides(strides, op); - hex_format_op_types(types, op); - hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), - ggml_op_name(node->op), names, dims, types, strides, buffs, req.flags); - + hex_print_op_info(op, sess, req.flags); if (opt_verbose > 1) { hex_dump_dspbuf(src0, &bufs[0]); hex_dump_dspbuf(src1, &bufs[1]); @@ -2886,71 +2704,33 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) { } dspqueue_buffer bufs[3]; - int n_bufs = 0; - - memset(bufs, 0, sizeof(bufs)); // First buffer = Only Operand of Unary op // This is a buffer that the CPU writes and the DSP reads, so we'll // need to flush CPU caches and invalidate DSP ones. On platforms // with I/O coherency support the framework will automatically skip // cache operations where possible. - auto src0_buf = static_cast(src0->buffer->context); - bufs[n_bufs].fd = src0_buf->fd; - bufs[n_bufs].ptr = src0->data; - bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base; - bufs[n_bufs].size = ggml_nbytes(src0); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; - ++n_bufs; + size_t n_bufs = dspqueue_buffers_init(bufs, src0, true, true); - if (src1) { - // Second buffer = Second Operand of Binary op - // This is a buffer that the CPU writes and the DSP reads, so we'll - // need to flush CPU caches and invalidate DSP ones. On platforms - // with I/O coherency support the framework will automatically skip - // cache operations where possible. - auto src1_buf = static_cast(src1->buffer->context); - bufs[n_bufs].fd = src1_buf->fd; - bufs[n_bufs].ptr = src1->data; - bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base; - bufs[n_bufs].size = ggml_nbytes(src1); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP - ++n_bufs; - } + // Second buffer(nullable) = Second Operand of Binary op + // This is a buffer that the CPU writes and the DSP reads, so we'll + // need to flush CPU caches and invalidate DSP ones. On platforms + // with I/O coherency support the framework will automatically skip + // cache operations where possible. + n_bufs += dspqueue_buffers_init(&bufs[n_bufs], src1, true, true); // Second or third buffer = Output Activations. We'll handle DSP // Second buffer = Output Activations. We'll handle DSP // cache maintenance in the response message but need to flush // CPU caches to ensure any previously written dirty lines are // written out before writes from the DSP start. - auto dst_buf = static_cast(dst->buffer->context); - bufs[n_bufs].fd = dst_buf->fd; - bufs[n_bufs].ptr = dst->data; - bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base; - bufs[n_bufs].size = ggml_nbytes(dst); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); - ++n_bufs; + n_bufs += dspqueue_buffers_init(&bufs[n_bufs], dst, true, false); // Primary DSP session from the src0 tensor - ggml_hexagon_session * sess = src0_buf->sess; + auto * sess = get_session_from_tensor(src0); if (opt_verbose) { - char dims[64 * GGML_MAX_SRC]; - char strides[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_strides(strides, op); - hex_format_op_types(types, op); - hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), - names, dims, types, strides, buffs, req.flags); + hex_print_op_info(op, sess, req.flags); if (opt_verbose > 1) { hex_dump_dspbuf(src0, &bufs[0]); if (src1) { @@ -3023,85 +2803,40 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) { } dspqueue_buffer bufs[4]; - int n_bufs = 0; - - memset(bufs, 0, sizeof(bufs)); // First buffer // This is a buffer that the CPU writes and the DSP reads, so we'll // need to flush CPU caches and invalidate DSP ones. On platforms // with I/O coherency support the framework will automatically skip // cache operations where possible. - auto src0_buf = static_cast(src0->buffer->context); - bufs[n_bufs].fd = src0_buf->fd; - bufs[n_bufs].ptr = src0->data; - bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base; - bufs[n_bufs].size = ggml_nbytes(src0); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; - ++n_bufs; + size_t n_bufs = dspqueue_buffers_init(bufs, src0, true, true); // Second buffer // This is a buffer that the CPU writes and the DSP reads, so we'll // need to flush CPU caches and invalidate DSP ones. On platforms // with I/O coherency support the framework will automatically skip // cache operations where possible. - auto src1_buf = static_cast(src1->buffer->context); - bufs[n_bufs].fd = src1_buf->fd; - bufs[n_bufs].ptr = src1->data; - bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base; - bufs[n_bufs].size = ggml_nbytes(src1); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP - ++n_bufs; + n_bufs += dspqueue_buffers_init(&bufs[n_bufs], src1, true, true); - if (src2) { - // Third buffer - // This is a buffer that the CPU writes and the DSP reads, so we'll - // need to flush CPU caches and invalidate DSP ones. On platforms - // with I/O coherency support the framework will automatically skip - // cache operations where possible. - auto src2_buf = static_cast(src2->buffer->context); - bufs[n_bufs].fd = src2_buf->fd; - bufs[n_bufs].ptr = src2->data; - bufs[n_bufs].offset = (uint8_t *) src2->data - src2_buf->base; - bufs[n_bufs].size = ggml_nbytes(src2); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP - ++n_bufs; - } + // Third buffer(nullable) + // This is a buffer that the CPU writes and the DSP reads, so we'll + // need to flush CPU caches and invalidate DSP ones. On platforms + // with I/O coherency support the framework will automatically skip + // cache operations where possible. + n_bufs += dspqueue_buffers_init(&bufs[n_bufs], src2, true, true); // Final buffer = Output Activations. We'll handle DSP // Second buffer = Output Activations. We'll handle DSP // cache maintenance in the response message but need to flush // CPU caches to ensure any previously written dirty lines are // written out before writes from the DSP start. - auto dst_buf = static_cast(dst->buffer->context); - bufs[n_bufs].fd = dst_buf->fd; - bufs[n_bufs].ptr = dst->data; - bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base; - bufs[n_bufs].size = ggml_nbytes(dst); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); - ++n_bufs; + n_bufs += dspqueue_buffers_init(&bufs[n_bufs], dst, true, false); // Primary DSP session from the src0 tensor - ggml_hexagon_session * sess = src0_buf->sess; + auto * sess = get_session_from_tensor(src0); if (opt_verbose) { - char dims[64 * GGML_MAX_SRC]; - char strides[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_strides(strides, op); - hex_format_op_types(types, op); - hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), - names, dims, types, strides, buffs, req.flags); + hex_print_op_info(op, sess, req.flags); if (opt_verbose > 1) { hex_dump_dspbuf(src0, &bufs[0]); if (src1) { diff --git a/ggml/src/ggml-hexagon/htp-utils.c b/ggml/src/ggml-hexagon/htp-utils.c index e8a035af8c6..3f335bf71c0 100644 --- a/ggml/src/ggml-hexagon/htp-utils.c +++ b/ggml/src/ggml-hexagon/htp-utils.c @@ -390,6 +390,12 @@ int get_hex_arch_ver(int domain, int * arch) { } switch (arch_ver.capability & 0xff) { + case 0x68: + *arch = 68; + return 0; + case 0x69: + *arch = 69; + return 0; case 0x73: *arch = 73; return 0; diff --git a/ggml/src/ggml-hexagon/htp/htp-dma.h b/ggml/src/ggml-hexagon/htp/htp-dma.h index 4d0d54ce859..7d3fc4078cc 100644 --- a/ggml/src/ggml-hexagon/htp/htp-dma.h +++ b/ggml/src/ggml-hexagon/htp/htp-dma.h @@ -66,6 +66,13 @@ static inline bool dma_queue_push(dma_queue * q, desc->desctype = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1; desc->dstbypass = 1; desc->srcbypass = 1; +#if __HVX_ARCH__ >= 73 + desc->dstbypass = 1; + desc->srcbypass = 1; +#else + desc->dstbypass = 0; + desc->srcbypass = 1; +#endif desc->order = 0; desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE; desc->src = (void *) src; diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.c b/ggml/src/ggml-hexagon/htp/hvx-exp.c index d0735e9325e..21bf46a542f 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-exp.c +++ b/ggml/src/ggml-hexagon/htp/hvx-exp.c @@ -16,13 +16,8 @@ #include "hvx-utils.h" #include "ops-utils.h" -static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec) { - static const float kInf = INFINITY; - static const float kMaxExp = 88.02f; // log(INF) - - const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); - const HVX_Vector inf = hvx_vec_splat_fp32(kInf); - const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp); +static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) { + const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp); HVX_Vector out = hvx_vec_exp_fp32(in_vec); @@ -47,6 +42,12 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int HVX_Vector vec_out = Q6_V_vzero(); + static const float kInf = INFINITY; + static const float kMaxExp = 88.02f; // log(INF) + + const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_fp32(kInf); + if (0 == unaligned_loop) { HVX_Vector * p_vec_in1 = (HVX_Vector *) src; HVX_Vector * p_vec_out = (HVX_Vector *) dst; @@ -55,9 +56,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { if (true == negate) { HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++); - *p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in); + *p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); } else { - *p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++); + *p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++, max_exp, inf); } } } else { @@ -67,9 +68,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int if (true == negate) { HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in); - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); } else { - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in, max_exp, inf); } } } @@ -83,9 +84,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int if (true == negate) { HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in); - vec_out = hvx_vec_exp_fp32_guard(neg_vec_in); + vec_out = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); } else { - vec_out = hvx_vec_exp_fp32_guard(in); + vec_out = hvx_vec_exp_fp32_guard(in, max_exp, inf); } hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out); diff --git a/ggml/src/ggml-hexagon/htp/hvx-inverse.c b/ggml/src/ggml-hexagon/htp/hvx-inverse.c index 953d3e6c167..4d70634fcd4 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-inverse.c +++ b/ggml/src/ggml-hexagon/htp/hvx-inverse.c @@ -16,6 +16,15 @@ #include "hvx-utils.h" #include "ops-utils.h" +static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) { + HVX_Vector out = hvx_vec_inverse_fp32(v_sf); + + HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask); + const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out); + + return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out); +} + void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { int left_over = num_elems & (VLEN_FP32 - 1); int num_elems_whole = num_elems - left_over; @@ -32,19 +41,22 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n"); } + static const uint32_t kNanInfMask = 0x7f800000; + const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(kNanInfMask); + if (0 == unaligned_loop) { HVX_Vector * p_vec_in = (HVX_Vector *) src; HVX_Vector * p_vec_out = (HVX_Vector *) dst; #pragma unroll(4) for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - *p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++); + *p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++, nan_inf_mask); } } else { #pragma unroll(4) for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in, nan_inf_mask); } } @@ -53,7 +65,7 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const float * dstf = (float *) dst + num_elems_whole; HVX_Vector in = *(HVX_UVector *) srcf; - HVX_Vector out = hvx_vec_inverse_fp32_guard(in); + HVX_Vector out = hvx_vec_inverse_fp32_guard(in, nan_inf_mask); hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out); } diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 5f94645cde3..80658105c55 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -21,6 +21,26 @@ typedef union { float fp32[VLEN_FP32]; } __attribute__((aligned(VLEN), packed)) HVX_VectorAlias; +/* Q6_Vsf_equals_Vw is only available on v73+.*/ +#if __HVX_ARCH__ < 73 +static inline HVX_Vector int32_to_qfloat(HVX_Vector const in) +{ + HVX_Vector const vzero = Q6_V_vzero(); + HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero); + HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in); + HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift); + HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift); + HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized); + HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp)); + return ret; +} + +static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) +{ + return Q6_Vsf_equals_Vqf32(int32_to_qfloat(in)); +} +#endif + static inline HVX_Vector hvx_vec_splat_fp32(float i) { union { float f; @@ -726,24 +746,6 @@ static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) { return Q6_Vsf_equals_Vqf32(r_qf); } -static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf) { - static const float kInf = INFINITY; - static const uint32_t kNanMask = 0x7fffffff; - static const uint32_t kNanMin = 0x7f800000; - - const HVX_Vector inf = hvx_vec_splat_fp32(kInf); - const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(inf, v_sf); - - HVX_Vector out = hvx_vec_inverse_fp32(v_sf); - - const HVX_Vector nan_mask = Q6_V_vsplat_R(kNanMask); - const HVX_Vector nan_min = Q6_V_vsplat_R(kNanMin); - HVX_Vector masked_out = Q6_V_vand_VV(out, nan_mask); - const HVX_VectorPred pred = Q6_Q_vcmp_gtand_QVuwVuw(pred_inf, nan_min, masked_out); - - return Q6_V_vmux_QVV(pred, out, Q6_V_vzero()); -} - #define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 #define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777 #define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267 @@ -958,14 +960,16 @@ static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) { return Q6_Vsf_equals_Vqf32(temp); } -static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v) { - static const float kMaxExp = -88.02f; // log(INF) - - const HVX_Vector max_exp = Q6_V_vsplat_R(*((uint32_t *) &kMaxExp)); - const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(v, max_exp); +static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v, + HVX_Vector one, + HVX_Vector max_exp, + HVX_Vector min_exp) { + const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v); + const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp); HVX_Vector out = hvx_vec_fast_sigmoid_fp32(v); - return Q6_V_vmux_QVV(pred_inf, out, Q6_V_vzero()); + out = Q6_V_vmux_QVV(pred_max, out, one); + return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero()); } static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { @@ -977,9 +981,16 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * const HVX_Vector * restrict v_src = (HVX_Vector *) src; HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + static const float kMinExp = -87.f; // 0 + static const float kMaxExp = 87.f; // 1 + + const HVX_Vector one = hvx_vec_splat_fp32(1.f); + const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); + const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp); + #pragma unroll(4) for (int i = 0; i < step_of_1; i++) { - v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i]); + v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i], one, max_exp, min_exp); } } diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 10e27333243..b60b352a7b4 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -143,16 +143,25 @@ AEEResult htp_iface_disable_etm(remote_handle64 handle) { } static int vtcm_acquire(struct htp_context * ctx) { + int err; if (!ctx->vtcm_valid) { // Temporarily bump thread priority to make sure it's higher than other sessions. // This way the resource manager will notify the other thread to release VTCM. // Note that we need to reaquire VTCM at normal priority for this to work next time. qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10); - HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + if (err != 0) { + FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); + abort(); + } HAP_compute_res_release_cached(ctx->vtcm_rctx); qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio); - HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + if (err != 0) { + FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); + abort(); + } ctx->vtcm_valid = true; } @@ -201,7 +210,7 @@ static int vtcm_alloc(struct htp_context * ctx) { HAP_compute_res_attr_init(&attr); HAP_compute_res_attr_set_serialize(&attr, 0); HAP_compute_res_attr_set_cache_mode(&attr, 1); - HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size); + HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, 0, vtcm_size); HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx); HAP_compute_res_attr_set_hmx_param(&attr, 1); diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 16afa50f5b0..00419bcba6b 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -24,6 +24,10 @@ #include "hvx-utils.h" #include "ops-utils.h" +// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we cant include ggml.h +#define HTP_ROPE_TYPE_NORMAL 0 +#define HTP_ROPE_TYPE_NEOX 2 + #define htp_rope_preamble \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ @@ -146,6 +150,57 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor); } +static void hvx_calc_rope_neox_f32(const float * restrict src0, + float * restrict dst, + const int num_elems, + const float * restrict theta_cache) { + // for (int i = 0; i < num_elems; i += 2) { + //const float cos_theta = theta_cache[i + 0]; + //const float sin_theta = theta_cache[i + 1]; + + //const float x0 = src[0]; + //const float x1 = src[num_elems/2]; + + //dst[0] = x0*cos_theta - x1*sin_theta; + //dst[num_elems/2] = x0*sin_theta + x1*cos_theta; + + //src += 1; + //dst += 1; + // } + + const uint8_t * restrict src0_curr = (const uint8_t *) src0; + const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache; + uint8_t * restrict dst_curr = (uint8_t *) dst; + + int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once + int half_size = (sizeof(float) * (num_elems / 2)); + + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v0 = *(HVX_Vector *) src0_curr; + HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size); + + HVX_Vector v2 = *(HVX_Vector *) theta_curr; + HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN); + + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4); + *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5); + + src0_curr += VLEN; + theta_curr += 2 * VLEN; + dst_curr += VLEN; + } +} + static void hvx_calc_rope_f32(const float * restrict src0, float * restrict dst, const int num_elems, @@ -212,6 +267,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, const struct htp_tensor * src2 = &octx->src2; struct htp_tensor * dst = &octx->dst; + const int32_t mode = rope_ctx->mode; + const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; + htp_rope_preamble; const int32_t * pos = (const int32_t *) src1->data; @@ -247,20 +305,35 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, float * dst_data_loc = dst_data; if (1 == opt_path) { - hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); + if (is_neox) { + hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); + } else { + hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); + } } else { for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) { const float cos_theta = wp0[i0 + 0]; const float sin_theta = wp0[i0 + 1]; - const float x0 = src_loc[0]; - const float x1 = src_loc[1]; + if (is_neox) { + const float x0 = src_loc[0]; + const float x1 = src_loc[rope_ctx->n_dims/2]; + + dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; + dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta; + + src_loc += 1; + dst_data_loc += 1; + } else { + const float x0 = src_loc[0]; + const float x1 = src_loc[1]; - dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; - dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta; + dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; + dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta; - src_loc += 2; - dst_data_loc += 2; + src_loc += 2; + dst_data_loc += 2; + } } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 0eefc0b137b..329500a03e0 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l return res; } +// note: reuse the argsort kernel for top_k +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_TOP_K); + + char base[256]; + char name[256]; + + // note: the top_k kernel is always descending order + ggml_sort_order order = GGML_SORT_ORDER_DESC; + + const char * order_str = "undefined"; + switch (order) { + case GGML_SORT_ORDER_ASC: order_str = "asc"; break; + case GGML_SORT_ORDER_DESC: order_str = "desc"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_TOP_K); + + char base[256]; + char name[256]; + + ggml_sort_order order = GGML_SORT_ORDER_DESC; + + const char * order_str = "undefined"; + switch (order) { + case GGML_SORT_ORDER_ASC: order_str = "asc"; break; + case GGML_SORT_ORDER_DESC: order_str = "desc"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, const struct ggml_tensor * op, diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 39ee6e3427e..3976e622b9b 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index acf9dfd5fc1..09b1b503118 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_LEAKY_RELU: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: case GGML_OP_ARANGE: return true; case GGML_OP_FLASH_ATTN_EXT: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 0fae97029f1..342dc4f8c37 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -832,14 +832,19 @@ typedef struct { } ggml_metal_kargs_leaky_relu; typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; - int64_t ne03; + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + int32_t top_k; } ggml_metal_kargs_argsort; typedef struct { @@ -851,6 +856,11 @@ typedef struct { uint64_t nb01; uint64_t nb02; uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + int32_t top_k; int32_t len; } ggml_metal_kargs_argsort_merge; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e46970dad47..9871e976f23 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -406,6 +406,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_argsort(ctx, idx); } break; + case GGML_OP_TOP_K: + { + n_fuse = ggml_metal_op_top_k(ctx, idx); + } break; case GGML_OP_LEAKY_RELU: { n_fuse = ggml_metal_op_leaky_relu(ctx, idx); @@ -3678,14 +3682,19 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { } ggml_metal_kargs_argsort args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ nth, }; ggml_metal_encoder_set_pipeline(enc, pipeline); @@ -3705,15 +3714,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); ggml_metal_kargs_argsort_merge args_merge = { - .ne00 = ne00, - .ne01 = ne01, - .ne02 = ne02, - .ne03 = ne03, - .nb00 = nb00, - .nb01 = nb01, - .nb02 = nb02, - .nb03 = nb03, - .len = len, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ ne00, + /*.len =*/ len, }; // merges per row @@ -3737,6 +3751,118 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); + + // bitonic sort requires the number of elements to be power of 2 + int nth = 1; + while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + // blocks per row + const int npr = (ne00 + nth - 1)/nth; + + const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_tmp = bid_dst; + bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]); + + if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) { + std::swap(bid_dst, bid_tmp); + } + + const int top_k = ne0; + + ggml_metal_kargs_argsort args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices + }; + + if (npr > 1) { + args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k); + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); + + ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); + + int len = args.top_k; + + while (len < args.ne0) { + ggml_metal_op_concurrency_reset(ctx); + + // merges per row + const int nm = (args.ne0 + 2*len - 1) / (2*len); + + const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge))); + + ggml_metal_kargs_argsort_merge args_merge = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ args.ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements + /*.len =*/ len, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline_merge); + ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1); + + std::swap(bid_dst, bid_tmp); + + len <<= 1; + } + + return 1; +} + int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 332e550ee70..b5546146e13 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -81,6 +81,7 @@ int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx); int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index f6033ddc97b..70bf6f3d981 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ { res *= 2; } break; + case GGML_OP_TOP_K: + { + res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]); + } break; default: break; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 59e5761704e..73b45c762d9 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4670,11 +4670,12 @@ kernel void kernel_argsort_f32_i32( ushort3 ntg[[threads_per_threadgroup]]) { // bitonic sort const int col = tpitg[0]; + const int ib = tgpig[0] / args.ne01; - const int i00 = (tgpig[0]/args.ne01)*ntg.x; - const int i01 = tgpig[0]%args.ne01; - const int i02 = tgpig[1]; - const int i03 = tgpig[2]; + const int i00 = ib*ntg.x; + const int i01 = tgpig[0] % args.ne01; + const int i02 = tgpig[1]; + const int i03 = tgpig[2]; device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03); @@ -4710,9 +4711,11 @@ kernel void kernel_argsort_f32_i32( } } + const int64_t i0 = ib*args.top_k; + // copy the result to dst without the padding - if (i00 + col < args.ne00) { - dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03; + if (i0 + col < args.ne0 && col < args.top_k) { + dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03; dst[col] = shmem_i32[col]; } @@ -4747,22 +4750,22 @@ kernel void kernel_argsort_merge_f32_i32( const int start = im * (2 * args.len); - const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start))); - const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len))); + const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start))); + const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len))); const int total = len0 + len1; device const int32_t * tmp0 = tmp + start - + i01*args.ne00 - + i02*args.ne00*args.ne01 - + i03*args.ne00*args.ne01*args.ne02; + + i01*args.ne0 + + i02*args.ne0*args.ne01 + + i03*args.ne0*args.ne01*args.ne02; device const int32_t * tmp1 = tmp0 + args.len; dst += start - + i01*args.ne00 - + i02*args.ne00*args.ne01 - + i03*args.ne00*args.ne01*args.ne02; + + i01*args.top_k + + i02*args.top_k*args.ne01 + + i03*args.top_k*args.ne01*args.ne02; device const float * src0_row = (device const float *)(src0 + args.nb01*i01 @@ -4776,7 +4779,11 @@ kernel void kernel_argsort_merge_f32_i32( const int chunk = (total + ntg.x - 1) / ntg.x; const int k0 = tpitg.x * chunk; - const int k1 = min(k0 + chunk, total); + const int k1 = MIN(MIN(k0 + chunk, total), args.top_k); + + if (k0 >= args.top_k) { + return; + } if (k0 >= total) { return; diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4cb6afe9271..2319f7a9e25 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -6895,9 +6895,23 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co cl_context context = backend_ctx->context; if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){ - if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0){ - ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst); - return; + if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) { + // For KQ + if (ggml_is_permuted(src0) && ggml_is_permuted(src1) && + nb00 <= nb02 && + nb02 <= nb01 && + nb01 <= nb03 && + nb10 <= nb12 && + nb12 <= nb11 && + nb11 <= nb13) { + ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst); + return; + } + // For KQV + if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst); + return; + } } } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d4f27af8fc6..6cf15b43bb3 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1629,6 +1629,22 @@ class vk_perf_logger { timings[name].push_back(time); return; } + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + const ggml_tensor * dst = node; + const ggml_tensor * q = node->src[0]; + const ggml_tensor * k = node->src[1]; + const ggml_tensor * v = node->src[2]; + const ggml_tensor * m = node->src[3]; + std::stringstream name; + name << ggml_op_name(node->op) << + " dst(" << dst->ne[0] << "," << dst->ne[1] << "," << dst->ne[2] << "," << dst->ne[3] << "), " << + " q(" << q->ne[0] << "," << q->ne[1] << "," << q->ne[2] << "," << q->ne[3] << "), " << + " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " << + " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " << + " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")"; + timings[name.str()].push_back(time); + return; + } timings[ggml_op_name(node->op)].push_back(time); } private: @@ -2485,9 +2501,11 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { +static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) { if (hsv >= 192) { return 2; + } else if ((hsv | hsk) & 8) { + return 4; } else { return 8; } @@ -2519,9 +2537,9 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if ((hsv | hsk) & 8) { // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - return {get_fa_scalar_num_large_rows(hsv), 64}; + return {get_fa_scalar_num_large_rows(hsk, hsv), 64}; } else { - return {get_fa_scalar_num_large_rows(hsv), 32}; + return {get_fa_scalar_num_large_rows(hsk, hsv), 32}; } } } @@ -7724,7 +7742,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsv); + const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv); const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t tmpsh = wg_size * sizeof(float); @@ -7855,7 +7873,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSV); + max_gqa = get_fa_scalar_num_large_rows(HSK, HSV); break; case FA_COOPMAT2: max_gqa = get_fa_num_small_rows(FA_COOPMAT2); @@ -11381,13 +11399,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex } } -static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready); +static void ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready); // Returns true if node has enqueued work into the queue, false otherwise // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){ ggml_tensor * node = cgraph->nodes[node_idx]; - if (ggml_is_empty(node) || !node->buffer) { + if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) { return false; } @@ -11399,132 +11417,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr ggml_tensor * src2 = node->src[2]; ggml_tensor * src3 = node->src[3]; - switch (node->op) { - // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - case GGML_OP_NONE: - return false; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(node)) { - case GGML_UNARY_OP_EXP: - case GGML_UNARY_OP_SILU: - case GGML_UNARY_OP_GELU: - case GGML_UNARY_OP_GELU_ERF: - case GGML_UNARY_OP_GELU_QUICK: - case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_NEG: - case GGML_UNARY_OP_TANH: - case GGML_UNARY_OP_SIGMOID: - case GGML_UNARY_OP_HARDSIGMOID: - case GGML_UNARY_OP_HARDSWISH: - case GGML_UNARY_OP_ABS: - case GGML_UNARY_OP_SOFTPLUS: - case GGML_UNARY_OP_STEP: - case GGML_UNARY_OP_ROUND: - case GGML_UNARY_OP_CEIL: - case GGML_UNARY_OP_FLOOR: - case GGML_UNARY_OP_TRUNC: - break; - default: - return false; - } - break; - case GGML_OP_GLU: - switch (ggml_get_glu_op(node)) { - case GGML_GLU_OP_GEGLU: - case GGML_GLU_OP_REGLU: - case GGML_GLU_OP_SWIGLU: - case GGML_GLU_OP_SWIGLU_OAI: - case GGML_GLU_OP_GEGLU_ERF: - case GGML_GLU_OP_GEGLU_QUICK: - break; - default: - return false; - } - break; - case GGML_OP_ADD: - { - int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops; - if (next_node_idx < cgraph->n_nodes && - cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM && - cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] && - ggml_nrows(cgraph->nodes[next_node_idx]) == 1 && - ctx->device->add_rms_fusion) { - uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]); - ctx->do_add_rms_partials_offset_calculation = true; - if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) { - ctx->do_add_rms_partials = true; - } + if (node->op == GGML_OP_ADD) { + int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops; + if (next_node_idx < cgraph->n_nodes && + cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM && + cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] && + ggml_nrows(cgraph->nodes[next_node_idx]) == 1 && + ctx->device->add_rms_fusion) { + uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]); + ctx->do_add_rms_partials_offset_calculation = true; + if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) { + ctx->do_add_rms_partials = true; } - } break; - case GGML_OP_REPEAT: - case GGML_OP_REPEAT_BACK: - case GGML_OP_GET_ROWS: - case GGML_OP_ADD_ID: - case GGML_OP_ACC: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_ADD1: - case GGML_OP_ARANGE: - case GGML_OP_FILL: - case GGML_OP_CONCAT: - case GGML_OP_UPSCALE: - case GGML_OP_SCALE: - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - case GGML_OP_LOG: - case GGML_OP_CLAMP: - case GGML_OP_PAD: - case GGML_OP_ROLL: - case GGML_OP_CPY: - case GGML_OP_SET_ROWS: - case GGML_OP_CONT: - case GGML_OP_DUP: - case GGML_OP_SILU_BACK: - case GGML_OP_NORM: - case GGML_OP_GROUP_NORM: - case GGML_OP_RMS_NORM: - case GGML_OP_RMS_NORM_BACK: - case GGML_OP_L2_NORM: - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_SOFT_MAX: - case GGML_OP_SOFT_MAX_BACK: - case GGML_OP_ROPE: - case GGML_OP_ROPE_BACK: - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - case GGML_OP_ARGSORT: - case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - case GGML_OP_ARGMAX: - case GGML_OP_COUNT_EQUAL: - case GGML_OP_IM2COL: - case GGML_OP_IM2COL_3D: - case GGML_OP_TIMESTEP_EMBEDDING: - case GGML_OP_CONV_TRANSPOSE_1D: - case GGML_OP_POOL_2D: - case GGML_OP_CONV_2D: - case GGML_OP_CONV_TRANSPOSE_2D: - case GGML_OP_CONV_2D_DW: - case GGML_OP_RWKV_WKV6: - case GGML_OP_RWKV_WKV7: - case GGML_OP_SSM_SCAN: - case GGML_OP_SSM_CONV: - case GGML_OP_LEAKY_RELU: - case GGML_OP_FLASH_ATTN_EXT: - case GGML_OP_OPT_STEP_ADAMW: - case GGML_OP_OPT_STEP_SGD: - break; - default: - std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; - GGML_ABORT("fatal error"); + } } vk_context compute_ctx; @@ -11961,145 +11866,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr ctx->compute_ctx.reset(); - bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready); - if (!ok) { - if (node->op == GGML_OP_UNARY) { - std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; - } else if (node->op == GGML_OP_GLU) { - std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast(node->op_params[0])) << ")" << std::endl; - } else { - std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; - } - } - + ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready); } return true; } -static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) { +static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) { GGML_UNUSED(cgraph); - ggml_backend_buffer * buf = nullptr; - - switch (tensor->op) { - case GGML_OP_ADD: - case GGML_OP_ACC: - case GGML_OP_GET_ROWS: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_ADD1: - case GGML_OP_ARANGE: - case GGML_OP_FILL: - case GGML_OP_ADD_ID: - case GGML_OP_CONCAT: - case GGML_OP_UPSCALE: - case GGML_OP_SCALE: - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - case GGML_OP_LOG: - case GGML_OP_CLAMP: - case GGML_OP_PAD: - case GGML_OP_ROLL: - case GGML_OP_CPY: - case GGML_OP_SET_ROWS: - case GGML_OP_CONT: - case GGML_OP_DUP: - case GGML_OP_SILU_BACK: - case GGML_OP_NORM: - case GGML_OP_GROUP_NORM: - case GGML_OP_RMS_NORM: - case GGML_OP_RMS_NORM_BACK: - case GGML_OP_L2_NORM: - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_SOFT_MAX: - case GGML_OP_SOFT_MAX_BACK: - case GGML_OP_ROPE: - case GGML_OP_ROPE_BACK: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - case GGML_OP_NONE: - case GGML_OP_ARGSORT: - case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - case GGML_OP_ARGMAX: - case GGML_OP_COUNT_EQUAL: - case GGML_OP_IM2COL: - case GGML_OP_IM2COL_3D: - case GGML_OP_TIMESTEP_EMBEDDING: - case GGML_OP_CONV_TRANSPOSE_1D: - case GGML_OP_POOL_2D: - case GGML_OP_CONV_2D: - case GGML_OP_CONV_TRANSPOSE_2D: - case GGML_OP_CONV_2D_DW: - case GGML_OP_RWKV_WKV6: - case GGML_OP_RWKV_WKV7: - case GGML_OP_SSM_SCAN: - case GGML_OP_SSM_CONV: - case GGML_OP_LEAKY_RELU: - case GGML_OP_REPEAT: - case GGML_OP_REPEAT_BACK: - case GGML_OP_OPT_STEP_ADAMW: - case GGML_OP_OPT_STEP_SGD: - buf = tensor->buffer; - break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(tensor)) { - case GGML_UNARY_OP_EXP: - case GGML_UNARY_OP_SILU: - case GGML_UNARY_OP_GELU: - case GGML_UNARY_OP_GELU_ERF: - case GGML_UNARY_OP_GELU_QUICK: - case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_NEG: - case GGML_UNARY_OP_TANH: - case GGML_UNARY_OP_SIGMOID: - case GGML_UNARY_OP_HARDSIGMOID: - case GGML_UNARY_OP_HARDSWISH: - case GGML_UNARY_OP_ABS: - case GGML_UNARY_OP_SOFTPLUS: - case GGML_UNARY_OP_STEP: - case GGML_UNARY_OP_ROUND: - case GGML_UNARY_OP_CEIL: - case GGML_UNARY_OP_FLOOR: - case GGML_UNARY_OP_TRUNC: - buf = tensor->buffer; - break; - default: - return false; - } - break; - case GGML_OP_GLU: - switch (ggml_get_glu_op(tensor)) { - case GGML_GLU_OP_GEGLU: - case GGML_GLU_OP_REGLU: - case GGML_GLU_OP_SWIGLU: - case GGML_GLU_OP_SWIGLU_OAI: - case GGML_GLU_OP_GEGLU_ERF: - case GGML_GLU_OP_GEGLU_QUICK: - buf = tensor->buffer; - break; - default: - return false; - } - break; - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - case GGML_OP_FLASH_ATTN_EXT: - buf = tensor->buffer; - - break; - default: - return false; - } - - if (buf == nullptr) { - return false; - } + GGML_UNUSED(tensor); VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")"); @@ -12143,8 +11917,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * subctx->out_memcpys.clear(); subctx->memsets.clear(); } - - return true; } // Clean up after graph processing is done diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a5846a23937..b99345a2e93 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -990,6 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARANGE", "TIMESTEP_EMBEDDING", "ARGSORT", + "TOP_K", "LEAKY_RELU", "TRI", "FILL", @@ -1023,7 +1024,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); +static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1098,6 +1099,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "arange(start, stop, step)", "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", + "top_k(x)", "leaky_relu(x)", "tri(x)", "fill(x, c)", @@ -1131,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); +static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5036,28 +5038,6 @@ struct ggml_tensor * ggml_roll( return result; } -// ggml_arange - -struct ggml_tensor * ggml_arange( - struct ggml_context * ctx, - float start, - float stop, - float step) { - GGML_ASSERT(stop > start); - - const int64_t steps = (int64_t) ceilf((stop - start) / step); - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); - - ggml_set_op_params_f32(result, 0, start); - ggml_set_op_params_f32(result, 1, stop); - ggml_set_op_params_f32(result, 2, step); - - result->op = GGML_OP_ARANGE; - - return result; -} - // ggml_timestep_embedding struct ggml_tensor * ggml_timestep_embedding( @@ -5139,6 +5119,7 @@ struct ggml_tensor * ggml_argsort( struct ggml_tensor * a, enum ggml_sort_order order) { GGML_ASSERT(a->ne[0] <= INT32_MAX); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); ggml_set_op_params_i32(result, 0, (int32_t) order); @@ -5149,9 +5130,9 @@ struct ggml_tensor * ggml_argsort( return result; } -// ggml_top_k +// ggml_argsort_top_k -struct ggml_tensor * ggml_top_k( +struct ggml_tensor * ggml_argsort_top_k( struct ggml_context * ctx, struct ggml_tensor * a, int k) { @@ -5167,6 +5148,44 @@ struct ggml_tensor * ggml_top_k( return result; } +// ggml_top_k + +struct ggml_tensor * ggml_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k) { + GGML_ASSERT(a->ne[0] >= k); + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]); + + result->op = GGML_OP_TOP_K; + result->src[0] = a; + + return result; +} + +// ggml_arange + +struct ggml_tensor * ggml_arange( + struct ggml_context * ctx, + float start, + float stop, + float step) { + GGML_ASSERT(stop > start); + + const int64_t steps = (int64_t) ceilf((stop - start) / step); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); + + ggml_set_op_params_f32(result, 0, start); + ggml_set_op_params_f32(result, 1, stop); + ggml_set_op_params_f32(result, 2, step); + + result->op = GGML_OP_ARANGE; + + return result; +} + // ggml_flash_attn_ext struct ggml_tensor * ggml_flash_attn_ext( diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1cd0efad4a8..6f5a742e04a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -25,6 +25,20 @@ class General: ALIGNMENT = "general.alignment" FILE_TYPE = "general.file_type" + # Recommended Sampler Parameters + SAMPLING_SEQUENCE = "general.sampling.sequence" + SAMPLING_TOP_K = "general.sampling.top_k" + SAMPLING_TOP_P = "general.sampling.top_p" + SAMPLING_MIN_P = "general.sampling.min_p" + SAMPLING_XTC_PROBABILITY = "general.sampling.xtc_probability" + SAMPLING_XTC_THRESHOLD = "general.sampling.xtc_threshold" + SAMPLING_TEMP = "general.sampling.temp" + SAMPLING_PENALTY_LAST_N = "general.sampling.penalty_last_n" + SAMPLING_PENALTY_REPEAT = "general.sampling.penalty_repeat" + SAMPLING_MIROSTAT = "general.sampling.mirostat" + SAMPLING_MIROSTAT_TAU = "general.sampling.mirostat_tau" + SAMPLING_MIROSTAT_ETA = "general.sampling.mirostat_eta" + # Authorship Metadata NAME = "general.name" AUTHOR = "general.author" @@ -427,6 +441,7 @@ class MODEL_ARCH(IntEnum): APERTUS = auto() COGVLM = auto() MINIMAXM2 = auto() + RND1 = auto() PANGU_EMBED = auto() @@ -797,6 +812,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.APERTUS: "apertus", MODEL_ARCH.MINIMAXM2: "minimax-m2", MODEL_ARCH.COGVLM: "cogvlm", + MODEL_ARCH.RND1: "rnd1", MODEL_ARCH.PANGU_EMBED: "pangu-embedded", } @@ -2991,6 +3007,23 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.VISEXP_UP, MODEL_TENSOR.VISEXP_DOWN, ], + MODEL_ARCH.RND1: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.PANGU_EMBED: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a051daeeb13..57ca2035fe2 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -4,6 +4,7 @@ import os import shutil import struct +import sys import tempfile from dataclasses import dataclass from enum import Enum, auto @@ -372,8 +373,10 @@ def add_tensor( self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None, ) -> None: - if self.endianess == GGUFEndian.BIG: - tensor.byteswap(inplace=True) + if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ + (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # Don't byteswap inplace since lazy copies cannot handle it + tensor = tensor.byteswap(inplace=False) if self.use_temp_file and self.temp_file is None: fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024) fp.seek(0) @@ -399,8 +402,10 @@ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') assert self.fout is not None - if self.endianess == GGUFEndian.BIG: - tensor.byteswap(inplace=True) + if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ + (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # Don't byteswap inplace since lazy copies cannot handle it + tensor = tensor.byteswap(inplace=False) file_id = -1 for i, tensors in enumerate(self.tensors): @@ -496,6 +501,42 @@ def add_custom_alignment(self, alignment: int) -> None: def add_file_type(self, ftype: int) -> None: self.add_uint32(Keys.General.FILE_TYPE, ftype) + def add_sampling_sequence(self, sequence: str) -> None: + self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence) + + def add_sampling_top_k(self, top_k: int) -> None: + self.add_int32(Keys.General.SAMPLING_TOP_K, top_k) + + def add_sampling_top_p(self, top_p: float) -> None: + self.add_float32(Keys.General.SAMPLING_TOP_P, top_p) + + def add_sampling_min_p(self, min_p: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIN_P, min_p) + + def add_sampling_xtc_probability(self, xtc_probability: float) -> None: + self.add_float32(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability) + + def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None: + self.add_float32(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold) + + def add_sampling_temp(self, temp: float) -> None: + self.add_float32(Keys.General.SAMPLING_TEMP, temp) + + def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None: + self.add_int32(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n) + + def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None: + self.add_float32(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat) + + def add_sampling_mirostat(self, mirostat: int) -> None: + self.add_int32(Keys.General.SAMPLING_MIROSTAT, mirostat) + + def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau) + + def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta) + def add_name(self, name: str) -> None: self.add_string(Keys.General.NAME, name) diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index 67efedbdbc5..e0d478ce95d 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -17,6 +17,20 @@ @dataclass class Metadata: + # Recommended Sampler Parameters to be written to GGUF KV Store + sampling_sequence: Optional[str] = None + sampling_top_k: Optional[int] = None + sampling_top_p: Optional[float] = None + sampling_min_p: Optional[float] = None + sampling_xtc_probability: Optional[float] = None + sampling_xtc_threshold: Optional[float] = None + sampling_temp: Optional[float] = None + sampling_penalty_last_n: Optional[int] = None + sampling_penalty_repeat: Optional[float] = None + sampling_mirostat: Optional[int] = None + sampling_mirostat_tau: Optional[float] = None + sampling_mirostat_eta: Optional[float] = None + # Authorship Metadata to be written to GGUF KV Store name: Optional[str] = None author: Optional[str] = None @@ -54,15 +68,43 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat model_card = Metadata.load_model_card(model_path) hf_params = Metadata.load_hf_parameters(model_path) + gen_config = Metadata.load_generation_config(model_path) # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter # heuristics metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params) + if gen_config: + metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence) + metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k) + metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p) + metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p) + metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability) + metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold) + metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp) + metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n) + metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat) + metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat) + metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau) + metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta) + # Metadata Override File Provided # This is based on LLM_KV_NAMES mapping in llama.cpp metadata_override = Metadata.load_metadata_override(metadata_override_path) + metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence) + metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k) + metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p) + metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p) + metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability) + metadata.sampling_xtc_threshold = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold) + metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp) + metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n) + metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat) + metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat) + metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau) + metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta) + metadata.name = metadata_override.get(Keys.General.NAME, metadata.name) metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) @@ -172,6 +214,23 @@ def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]: with open(config_path, "r", encoding="utf-8") as f: return json.load(f) + @staticmethod + def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]: + if model_path is None or not model_path.is_dir(): + return {} + + generation_config_path = model_path / "generation_config.json" + + if not generation_config_path.is_file(): + return {} + + try: + with open(generation_config_path, "r", encoding="utf-8") as f: + return json.load(f) + except (json.JSONDecodeError, IOError): + # not all models have valid generation_config.json + return {} + @staticmethod def id_to_title(string): # Convert capitalization into title form unless acronym or version number @@ -546,6 +605,32 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter): assert self.name is not None + + if self.sampling_sequence is not None: + gguf_writer.add_sampling_sequence(self.sampling_sequence) + if self.sampling_top_k is not None: + gguf_writer.add_sampling_top_k(self.sampling_top_k) + if self.sampling_top_p is not None: + gguf_writer.add_sampling_top_p(self.sampling_top_p) + if self.sampling_min_p is not None: + gguf_writer.add_sampling_min_p(self.sampling_min_p) + if self.sampling_xtc_probability is not None: + gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability) + if self.sampling_xtc_threshold is not None: + gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold) + if self.sampling_temp is not None: + gguf_writer.add_sampling_temp(self.sampling_temp) + if self.sampling_penalty_last_n is not None: + gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n) + if self.sampling_penalty_repeat is not None: + gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat) + if self.sampling_mirostat is not None: + gguf_writer.add_sampling_mirostat(self.sampling_mirostat) + if self.sampling_mirostat_tau is not None: + gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau) + if self.sampling_mirostat_eta is not None: + gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta) + gguf_writer.add_name(self.name) if self.author is not None: diff --git a/include/llama.h b/include/llama.h index 8547226ff21..b52eaacfa7e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -246,6 +246,21 @@ extern "C" { LLAMA_KV_OVERRIDE_TYPE_STR, }; + enum llama_model_meta_key { + LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE, + LLAMA_MODEL_META_KEY_SAMPLING_TOP_K, + LLAMA_MODEL_META_KEY_SAMPLING_TOP_P, + LLAMA_MODEL_META_KEY_SAMPLING_MIN_P, + LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY, + LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD, + LLAMA_MODEL_META_KEY_SAMPLING_TEMP, + LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N, + LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA, + }; + struct llama_model_kv_override { enum llama_model_kv_override_type tag; @@ -518,6 +533,9 @@ extern "C" { // Get the number of metadata key/value pairs LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model); + // Get sampling metadata key name. Returns nullptr if the key is invalid + LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key); + // Get metadata key name by index LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size); diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index c9056b59c70..a879940eaee 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -781baf2a14d9e0aaee542b2e1bb918bfc4132199 +55bc9320a4aae82af18e23eefd5de319a755d7b9 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8ec95ee1762..f7a8c9841ec 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -115,6 +115,7 @@ add_library(llama models/qwen3vl-moe.cpp models/qwen3moe.cpp models/refact.cpp + models/rnd1.cpp models/rwkv6-base.cpp models/rwkv6.cpp models/rwkv6qwen2.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index b2eb2477f93..7ef87acf1b3 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -108,24 +108,37 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_APERTUS, "apertus" }, { LLM_ARCH_MINIMAX_M2, "minimax-m2" }, { LLM_ARCH_COGVLM, "cogvlm" }, + { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; static const std::map LLM_KV_NAMES = { - { LLM_KV_GENERAL_TYPE, "general.type" }, - { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, - { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, - { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, - { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, - { LLM_KV_GENERAL_NAME, "general.name" }, - { LLM_KV_GENERAL_AUTHOR, "general.author" }, - { LLM_KV_GENERAL_VERSION, "general.version" }, - { LLM_KV_GENERAL_URL, "general.url" }, - { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, - { LLM_KV_GENERAL_LICENSE, "general.license" }, - { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, - { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, + { LLM_KV_GENERAL_TYPE, "general.type" }, + { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, + { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, + { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, + { LLM_KV_GENERAL_SAMPLING_SEQUENCE, "general.sampling.sequence" }, + { LLM_KV_GENERAL_SAMPLING_TOP_K, "general.sampling.top_k" }, + { LLM_KV_GENERAL_SAMPLING_TOP_P, "general.sampling.top_p" }, + { LLM_KV_GENERAL_SAMPLING_MIN_P, "general.sampling.min_p" }, + { LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, "general.sampling.xtc_probability" }, + { LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, "general.sampling.xtc_threshold" }, + { LLM_KV_GENERAL_SAMPLING_TEMP, "general.sampling.temp" }, + { LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, "general.sampling.penalty_last_n" }, + { LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, "general.sampling.penalty_repeat" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT, "general.sampling.mirostat" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, "general.sampling.mirostat_tau" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, "general.sampling.mirostat_eta" }, + { LLM_KV_GENERAL_NAME, "general.name" }, + { LLM_KV_GENERAL_AUTHOR, "general.author" }, + { LLM_KV_GENERAL_VERSION, "general.version" }, + { LLM_KV_GENERAL_URL, "general.url" }, + { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, + { LLM_KV_GENERAL_LICENSE, "general.license" }, + { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, + { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, @@ -2446,6 +2459,26 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, }, }, + { + LLM_ARCH_RND1, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2722,6 +2755,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: + case LLM_ARCH_RND1: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index ae7fa222aca..9ad3157bf67 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -112,6 +112,7 @@ enum llm_arch { LLM_ARCH_APERTUS, LLM_ARCH_MINIMAX_M2, LLM_ARCH_COGVLM, + LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_UNKNOWN, }; @@ -122,6 +123,18 @@ enum llm_kv { LLM_KV_GENERAL_QUANTIZATION_VERSION, LLM_KV_GENERAL_ALIGNMENT, LLM_KV_GENERAL_FILE_TYPE, + LLM_KV_GENERAL_SAMPLING_SEQUENCE, + LLM_KV_GENERAL_SAMPLING_TOP_K, + LLM_KV_GENERAL_SAMPLING_TOP_P, + LLM_KV_GENERAL_SAMPLING_MIN_P, + LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, + LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, + LLM_KV_GENERAL_SAMPLING_TEMP, + LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, + LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, + LLM_KV_GENERAL_SAMPLING_MIROSTAT, + LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, + LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, LLM_KV_GENERAL_NAME, LLM_KV_GENERAL_AUTHOR, LLM_KV_GENERAL_VERSION, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 70a3ec62dfc..2aa6d52a242 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1248,7 +1248,7 @@ int llama_context::decode(const llama_batch & batch_inp) { // make the outputs have the same order they had in the user-provided batch // note: this is mostly relevant for recurrent models atm - if (!sorted_output) { + if (!sorted_output && n_outputs > 1) { GGML_ASSERT((size_t) n_outputs == out_ids.size()); // TODO: is there something more efficient which also minimizes swaps? diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 650e40ec6ff..1d012e09aba 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -961,14 +961,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // organize experts into n_expert_groups ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens] - ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens] + ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens] group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens] // get top n_group_used expert groups group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens] group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens] - ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens] + ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens] cb(expert_groups, "ffn_moe_group_topk", il); // mask out the other groups @@ -979,7 +979,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( } // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts, "ffn_moe_topk", il); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e703181a198..a042ea9632c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1036,6 +1036,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_RND1: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; + } break; case LLM_ARCH_QWEN2MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); @@ -1593,7 +1605,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_DEEPSEEK2: { - bool is_lite = (hparams.n_layer == 27); + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B + bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); if (!is_lite) { @@ -3401,6 +3414,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_QWEN3MOE: case LLM_ARCH_QWEN3VLMOE: + case LLM_ARCH_RND1: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4581,7 +4595,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_DEEPSEEK2: { - const bool is_lite = (hparams.n_layer == 27); + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B + const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -6718,7 +6733,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE) { + if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } @@ -6880,6 +6895,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: + case LLM_ARCH_RND1: { res = nullptr; } break; @@ -7073,6 +7089,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_RND1: + { + llm = std::make_unique(*this, params); + } + break; case LLM_ARCH_QWEN2VL: { llm = std::make_unique(*this, params); @@ -7593,6 +7614,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3MOE: case LLM_ARCH_LLADA_MOE: + case LLM_ARCH_RND1: case LLM_ARCH_OLMO2: case LLM_ARCH_OLMOE: case LLM_ARCH_PHI2: @@ -7665,6 +7687,24 @@ int32_t llama_model_meta_count(const llama_model * model) { return (int)model->gguf_kv.size(); } +const char * llama_model_meta_key_str(llama_model_meta_key key) { + switch (key) { + case LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE: return "general.sampling.sequence"; + case LLAMA_MODEL_META_KEY_SAMPLING_TOP_K: return "general.sampling.top_k"; + case LLAMA_MODEL_META_KEY_SAMPLING_TOP_P: return "general.sampling.top_p"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIN_P: return "general.sampling.min_p"; + case LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY: return "general.sampling.xtc_probability"; + case LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD: return "general.sampling.xtc_threshold"; + case LLAMA_MODEL_META_KEY_SAMPLING_TEMP: return "general.sampling.temp"; + case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N: return "general.sampling.penalty_last_n"; + case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT: return "general.sampling.penalty_repeat"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT: return "general.sampling.mirostat"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU: return "general.sampling.mirostat_tau"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA: return "general.sampling.mirostat_eta"; + default: return nullptr; + } +} + int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) { if (i < 0 || i >= (int)model->gguf_kv.size()) { if (buf_size > 0) { diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 68f72f72bb6..0b41f7ba8eb 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -4,7 +4,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - bool is_lite = (hparams.n_layer == 27); + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B + bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); diff --git a/src/models/models.h b/src/models/models.h index 4d7aeb4f42c..5f019c59be8 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -431,6 +431,10 @@ struct llm_build_refact : public llm_graph_context { llm_build_refact(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_rnd1 : public llm_graph_context { + llm_build_rnd1(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_rwkv6 : public llm_build_rwkv6_base { llm_build_rwkv6(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/rnd1.cpp b/src/models/rnd1.cpp new file mode 100644 index 00000000000..46b3dc3efca --- /dev/null +++ b/src/models/rnd1.cpp @@ -0,0 +1,126 @@ +#include "models.h" + +// RND1 is a Qwen3Moe AR model converted to diffusion model. +llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // Non-causal attention for diffusion + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2bb4b122247..2fc7d4c3c72 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -39,6 +39,7 @@ #include #include #include +#include static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); @@ -269,6 +270,34 @@ static double nmse(const float * a, const float * b, size_t n) { return mse_a_b / mse_a_0; } +// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap) +static double jdst(const int32_t * a, const int32_t * b, size_t n) { + std::unordered_map set_a; + std::unordered_map set_b; + + for (size_t i = 0; i < n; ++i) { + set_a[a[i]]++; + set_b[b[i]]++; + } + + size_t diff = 0; + + for (const auto & p : set_a) { + const int64_t na = p.second; + const int64_t nb = set_b.find(p.first) != set_b.end() ? set_b.at(p.first) : 0; + + diff += std::abs(na - nb); + } + + for (const auto & p : set_b) { + if (set_a.find(p.first) == set_a.end()) { + diff += p.second; + } + } + + return (double) diff / (2*n); +} + // maximum absolute asymmetry between a and b // asymmetry: (a - b) / (a + b) // This is more stable than relative error if one of the values fluctuates towards zero. @@ -1051,6 +1080,14 @@ struct test_case { return 1e-4; } + virtual double max_err() { + return max_nmse_err(); + } + + virtual double err(const float * a, const float * b, size_t n) { + return nmse(a, b, n); + } + virtual float grad_eps() { return 1e-1f; } @@ -1257,16 +1294,16 @@ struct test_case { // compare struct callback_userdata { bool ok; - double max_err; + test_case * tc; ggml_backend_t backend1; ggml_backend_t backend2; }; callback_userdata ud { true, - max_nmse_err(), + this, backend1, - backend2 + backend2, }; auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool { @@ -1314,9 +1351,9 @@ struct test_case { } } - double err = nmse(f1.data(), f2.data(), f1.size()); - if (err > ud->max_err) { - printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); + double err = ud->tc->err(f1.data(), f2.data(), f1.size()); + if (err > ud->tc->max_err()) { + printf("[%s] ERR = %.9f > %.9f ", ggml_op_desc(t1), err, ud->tc->max_err()); //for (int i = 0; i < (int) f1.size(); i++) { // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); //} @@ -4943,7 +4980,71 @@ struct test_argsort : public test_case { } }; -struct test_topk_moe: public test_case { +// GGML_OP_TOP_K +struct test_top_k : public test_case { + const ggml_type type; + const std::array ne; + const int k; + + std::string vars() override { + return VARS_TO_STR3(type, ne, k); + } + + test_top_k(ggml_type type = GGML_TYPE_F32, + std::array ne = {16, 10, 10, 10}, + int k = 4) + : type(type), ne(ne), k(k) {} + + double max_err() override { + return 0.0; + } + + double err(const float * a, const float * b, size_t n) override { + std::vector ia(n); + std::vector ib(n); + + double diff = 0.0f; + + for (size_t i = 0; i < n; i++) { + ia[i] = (int32_t) a[i]; + ib[i] = (int32_t) b[i]; + + // penalize the result if the data is not integer valued + diff += std::fabs(a[i] - ia[i]); + diff += std::fabs(b[i] - ib[i]); + } + + return diff + jdst(ia.data(), ib.data(), n); + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_top_k(ctx, a, k); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // initialize with unique values to avoid ties + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float)); + } + } + } +}; + +struct test_topk_moe : public test_case { const std::array ne; const int n_expert_used; const bool with_norm; @@ -4976,7 +5077,7 @@ struct test_topk_moe: public test_case { ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits); - ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] @@ -6953,9 +7054,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0})); - for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) { + for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_I32, GGML_TYPE_F16, GGML_TYPE_BF16 }) { for (bool use_view_slice : { true, false }) { for (std::array ne : std::initializer_list>{ {2, 1, 1, 1}, {2, 1, 3, 5}, {2, 3, 5, 7}, {1, 4, 4, 1}, {1, 8, 17, 1}, {10, 10, 10, 1} }) { @@ -7532,6 +7635,23 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection) } + for (int k : {1, 2, 3, 7, 15}) { + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1023, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1024, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1025, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16384, 1, 1, 1}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2047, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2048, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2049, 2, 1, 3}, k)); + } + + // exhaustive top_k tests + //for (int i = 1; i < 9999; ++i) { + // test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1)); + //} + for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) { test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true)); @@ -7819,6 +7939,7 @@ static std::vector> make_test_cases_perf() { for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048)); test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1)); } } @@ -7827,6 +7948,7 @@ static std::vector> make_test_cases_perf() { for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048)); test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1)); } } @@ -7837,6 +7959,7 @@ static std::vector> make_test_cases_perf() { for (int bs : {1, 4, 8, 512}) { for (ggml_type type_a : {GGML_TYPE_MXFP4}) { for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880)); test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1)); } } @@ -7854,6 +7977,9 @@ static std::vector> make_test_cases_perf() { } } + // Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012 + test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + for (int kv : { 4096, 8192, 16384, }) { for (int hs : { 64, 128, }) { for (int nr : { 1, 4, }) { @@ -7906,6 +8032,7 @@ static std::vector> make_test_cases_perf() { } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40)); return test_cases; } diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index 1fccfdd17f1..7fbca320162 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -13,9 +13,14 @@ endif() set(TARGET_SRCS server.cpp - utils.hpp server-http.cpp server-http.h + server-task.cpp + server-task.h + server-queue.cpp + server-queue.h + server-common.cpp + server-common.h ) set(PUBLIC_ASSETS index.html.gz diff --git a/tools/server/README.md b/tools/server/README.md index 8fd478eb328..c4e1759e4fd 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. **Features:** * LLM inference of F16 and quantized models on GPU and CPU * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes + * [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) compatible chat completions * Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510) * Parallel decoding with multi-user support * Continuous batching @@ -1343,6 +1344,77 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r }' ``` +### POST `/v1/messages`: Anthropic-compatible Messages API + +Given a list of `messages`, returns the assistant's response. Streaming is supported via Server-Sent Events. While no strong claims of compatibility with the Anthropic API spec are made, in our experience it suffices to support many apps. + +*Options:* + +See [Anthropic Messages API documentation](https://docs.anthropic.com/en/api/messages). Tool use requires `--jinja` flag. + +`model`: Model identifier (required) + +`messages`: Array of message objects with `role` and `content` (required) + +`max_tokens`: Maximum tokens to generate (default: 4096) + +`system`: System prompt as string or array of content blocks + +`temperature`: Sampling temperature 0-1 (default: 1.0) + +`top_p`: Nucleus sampling (default: 1.0) + +`top_k`: Top-k sampling + +`stop_sequences`: Array of stop sequences + +`stream`: Enable streaming (default: false) + +`tools`: Array of tool definitions (requires `--jinja`) + +`tool_choice`: Tool selection mode (`{"type": "auto"}`, `{"type": "any"}`, or `{"type": "tool", "name": "..."}`) + +*Examples:* + +```shell +curl http://localhost:8080/v1/messages \ + -H "Content-Type: application/json" \ + -H "x-api-key: your-api-key" \ + -d '{ + "model": "gpt-4", + "max_tokens": 1024, + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello!"} + ] + }' +``` + +### POST `/v1/messages/count_tokens`: Token Counting + +Counts the number of tokens in a request without generating a response. + +Accepts the same parameters as `/v1/messages`. The `max_tokens` parameter is not required. + +*Example:* + +```shell +curl http://localhost:8080/v1/messages/count_tokens \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4", + "messages": [ + {"role": "user", "content": "Hello!"} + ] + }' +``` + +*Response:* + +```json +{"input_tokens": 10} +``` + ## More examples ### Interactive mode diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 097c9440be2..ae25b6ddf72 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/utils.hpp b/tools/server/server-common.cpp similarity index 66% rename from tools/server/utils.hpp rename to tools/server/server-common.cpp index bf21726051e..2e3903b1b45 100644 --- a/tools/server/utils.hpp +++ b/tools/server/server-common.cpp @@ -1,487 +1,759 @@ -#pragma once - #include "common.h" #include "log.h" #include "llama.h" -#include "arg.h" // common_remote_get_content -#include "base64.hpp" #include "mtmd.h" #include "mtmd-helper.h" #include "chat.h" +#include "arg.h" // for common_remote_get_content; TODO: use download.h only +#include "base64.hpp" -#define JSON_ASSERT GGML_ASSERT -#include +#include "server-common.h" #include #include -#include -#include -#include -#include - -#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" - -using json = nlohmann::ordered_json; - -#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) -#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) -#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) -#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) - -#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) - -#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) - -using raw_buffer = std::vector; - -template -static T json_value(const json & body, const std::string & key, const T & default_value) { - // Fallback null to default value - if (body.contains(key) && !body.at(key).is_null()) { - try { - return body.at(key); - } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) { - LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what()); - return default_value; - } - } else { - return default_value; + +json format_error_response(const std::string & message, const enum error_type type) { + std::string type_str; + int code = 500; + switch (type) { + case ERROR_TYPE_INVALID_REQUEST: + type_str = "invalid_request_error"; + code = 400; + break; + case ERROR_TYPE_AUTHENTICATION: + type_str = "authentication_error"; + code = 401; + break; + case ERROR_TYPE_NOT_FOUND: + type_str = "not_found_error"; + code = 404; + break; + case ERROR_TYPE_SERVER: + type_str = "server_error"; + code = 500; + break; + case ERROR_TYPE_PERMISSION: + type_str = "permission_error"; + code = 403; + break; + case ERROR_TYPE_NOT_SUPPORTED: + type_str = "not_supported_error"; + code = 501; + break; + case ERROR_TYPE_UNAVAILABLE: + type_str = "unavailable_error"; + code = 503; + break; + case ERROR_TYPE_EXCEED_CONTEXT_SIZE: + type_str = "exceed_context_size_error"; + code = 400; + break; } + return json { + {"code", code}, + {"message", message}, + {"type", type_str}, + }; } -const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); +// +// random string / id +// -// thin wrapper around common_grammar_trigger with (de)serialization functions -struct server_grammar_trigger { - common_grammar_trigger value; +std::string random_string() { + static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); - server_grammar_trigger() = default; - server_grammar_trigger(const common_grammar_trigger & value) : value(value) {} - server_grammar_trigger(const json & in) { - value.type = (common_grammar_trigger_type) in.at("type").get(); - value.value = in.at("value").get(); - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - value.token = (llama_token) in.at("token").get(); - } - } + std::random_device rd; + std::mt19937 generator(rd()); - json to_json() const { - json out { - {"type", (int) value.type}, - {"value", value.value}, - }; - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - out["token"] = (int) value.token; - } - return out; + std::string result(32, ' '); + + for (int i = 0; i < 32; ++i) { + result[i] = str[generator() % str.size()]; } -}; + + return result; +} + +std::string gen_chatcmplid() { + return "chatcmpl-" + random_string(); +} + +std::string gen_tool_call_id() { + return random_string(); +} // -// tokenizer and input processing utils +// lora utils // -static bool json_is_array_of_numbers(const json & data) { - if (data.is_array()) { - for (const auto & e : data) { - if (!e.is_number_integer()) { +bool lora_all_alora(const std::vector & loras) { + bool found_alora = false; + for (const auto & lora : loras) { + if (lora.scale != 0) { + if (llama_adapter_get_alora_n_invocation_tokens(lora.ptr) == 0) { return false; } + found_alora = true; } - return true; } - return false; + return found_alora; } -// is array having BOTH numbers & strings? -static bool json_is_array_of_mixed_numbers_strings(const json & data) { - bool seen_string = false; - bool seen_number = false; - if (data.is_array()) { - for (const auto & e : data) { - seen_string |= e.is_string(); - seen_number |= e.is_number_integer(); - if (seen_number && seen_string) { - return true; - } +bool lora_should_clear_cache( + const std::vector & current, + const std::vector & next) { + + // This should always be called after determining that the two sets are + // _not_ equal. This assert is therefore some slightly wasted work and + // should be safe to remove as long as this method is called correctly. + GGML_ASSERT(!are_lora_equal(current, next)); + + return ( + !(lora_get_enabled_ids(current).empty() || lora_all_alora(current)) || + !lora_all_alora(next)); +} + +std::vector parse_lora_request( + const std::vector & lora_base, + const json & data) { + std::vector lora(lora_base); + int max_idx = lora.size(); + + // clear existing value + for (auto & entry : lora) { + entry.scale = 0.0f; + } + + // set value + for (const auto & entry : data) { + int id = json_value(entry, "id", -1); + float scale = json_value(entry, "scale", 0.0f); + if (0 <= id && id < max_idx) { + lora[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); } } - return false; + + return lora; } -// does array have any individual integers/tokens? -static bool json_is_array_and_contains_numbers(const json & data) { - if (data.is_array()) { - for (const auto & e : data) { - if (e.is_number_integer()) { - return true; - } - } +bool are_lora_equal( + const std::vector & l1, + const std::vector & l2) { + if (l1.size() != l2.size()) { return false; } - return false; + for (size_t i = 0; i < l1.size(); ++i) { + // we don't check lora.path to reduce the time complexity + if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { + return false; + } + } + return true; } -// get value by path(key1 / key2) -static json json_get_nested_values(const std::vector & paths, const json & js) { - json result = json::object(); - - for (const std::string & path : paths) { - json current = js; - const auto keys = string_split(path, /*separator*/ '/'); - bool valid_path = true; - for (const std::string & k : keys) { - if (valid_path && current.is_object() && current.contains(k)) { - current = current[k]; - } else { - valid_path = false; - } - } - if (valid_path) { - result[path] = current; +std::vector lora_get_enabled_ids(const std::vector & loras) { + std::vector enabled_ids; + for (size_t i = 0; i < loras.size(); ++i) { + if (loras[i].scale > 0) { + enabled_ids.push_back(i); } } - return result; + return enabled_ids; } -/** - * this handles 2 cases: - * - only string, example: "string" - * - mixed string and tokens, example: [12, 34, "string", 56, 78] - */ -static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { - // If `add_bos` is true, we only add BOS, when json_prompt is a string, - // or the first element of the json_prompt array is a string. - llama_tokens prompt_tokens; +// +// base64 utils (TODO: use the base64::decode from base64.hpp) +// - if (json_prompt.is_array()) { - bool first = true; - for (const auto & p : json_prompt) { - if (p.is_string()) { - auto s = p.template get(); +static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; - llama_tokens p; - if (first) { - p = common_tokenize(vocab, s, add_special, parse_special); - first = false; - } else { - p = common_tokenize(vocab, s, false, parse_special); - } +static inline bool is_base64(uint8_t c) { + return (isalnum(c) || (c == '+') || (c == '/')); +} - prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } else { - if (first) { - first = false; - } +static inline raw_buffer base64_decode(const std::string & encoded_string) { + int i = 0; + int j = 0; + int in_ = 0; - prompt_tokens.push_back(p.template get()); + int in_len = encoded_string.size(); + + uint8_t char_array_4[4]; + uint8_t char_array_3[3]; + + raw_buffer ret; + + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i == 4) { + for (i = 0; i < 4; i++) { + char_array_4[i] = base64_chars.find(char_array_4[i]); + } + + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) { + ret.push_back(char_array_3[i]); } + + i = 0; } - } else { - auto s = json_prompt.template get(); - prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); } - return prompt_tokens; -} + if (i) { + for (j = i; j < 4; j++) { + char_array_4[j] = 0; + } -// return the last index of character that can form a valid string -// if the last character is potentially cut in half, return the index before the cut -// if validate_utf8(text) == text.size(), then the whole text is valid utf8 -static size_t validate_utf8(const std::string& text) { - size_t len = text.size(); - if (len == 0) return 0; + for (j = 0; j < 4; j++) { + char_array_4[j] = base64_chars.find(char_array_4[j]); + } - // Check the last few bytes to see if a multi-byte character is cut off - for (size_t i = 1; i <= 4 && i <= len; ++i) { - unsigned char c = text[len - i]; - // Check for start of a multi-byte sequence from the end - if ((c & 0xE0) == 0xC0) { - // 2-byte character start: 110xxxxx - // Needs at least 2 bytes - if (i < 2) return len - i; - } else if ((c & 0xF0) == 0xE0) { - // 3-byte character start: 1110xxxx - // Needs at least 3 bytes - if (i < 3) return len - i; - } else if ((c & 0xF8) == 0xF0) { - // 4-byte character start: 11110xxx - // Needs at least 4 bytes - if (i < 4) return len - i; + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; j < i - 1; j++) { + ret.push_back(char_array_3[j]); } } - // If no cut-off multi-byte character is found, return full length - return len; + return ret; } // -// template utils +// server_tokens implementation // -// format infill task -static llama_tokens format_infill( - const llama_vocab * vocab, - const json & input_prefix, - const json & input_suffix, - const json & input_extra, - const int n_batch, - const int n_predict, - const int n_ctx, - const bool spm_infill, - const llama_tokens & tokens_prompt - ) { - // TODO: optimize this block by reducing memory allocations and movement +server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { + for (size_t i = 0; i < mtmd_chunks.size(); ++i) { + push_back(mtmd_chunks[i]); + } +} - // use FIM repo-level pattern: - // ref: https://arxiv.org/pdf/2409.12186 - // - // [FIM_REP]myproject - // [FIM_SEP]filename0 - // extra chunk 0 - // [FIM_SEP]filename1 - // extra chunk 1 - // ... - // [FIM_SEP]filename - // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt - // - llama_tokens extra_tokens; - extra_tokens.reserve(n_ctx); +server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { +} - auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); - auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); +llama_pos server_tokens::pos_next() const { + if (!has_mtmd) { + return tokens.size(); + } - if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { - // TODO: make project name an input - static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); + llama_pos res = tokens.size(); - extra_tokens.push_back(llama_vocab_fim_rep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); + for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { + const auto & chunk = it->second; + res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get()); } - for (const auto & chunk : input_extra) { - // { "text": string, "filename": string } - const std::string text = json_value(chunk, "text", std::string()); - const std::string filename = json_value(chunk, "filename", std::string("tmp")); - if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { - const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); + return res; +} - extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); +std::string server_tokens::str() const { + std::ostringstream oss; + oss << "tokens: "; + for (size_t idx = 0; idx < tokens.size(); ++idx) { + llama_token t = tokens[idx]; + oss << "idx:" << idx << " "; + if (t == LLAMA_TOKEN_NULL) { + oss << " "; } else { - // chunk separator in binary form to avoid confusing the AI - static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; - static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); - - extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); + oss << t << " "; } - - const auto chunk_tokens = common_tokenize(vocab, text, false, false); - extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); } + oss << "\n"; + oss << "image idx: "; + for (const auto & it : map_idx_to_media) { + oss << it.first << ", "; + } + return oss.str(); +} - if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { - // TODO: current filename - static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); - - extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); +const mtmd::input_chunk_ptr & server_tokens::find_chunk(size_t idx) const { + auto it = map_idx_to_media.find(idx); + if (it != map_idx_to_media.end()) { + return it->second; } + throw std::runtime_error("Chunk not found"); +} - // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) - const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); - const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); +void server_tokens::push_back(llama_token tok) { + if (tok == LLAMA_TOKEN_NULL) { + throw std::runtime_error("Invalid token"); + } + tokens.emplace_back(tok); +} - SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); +void server_tokens::push_back(const mtmd_input_chunk * chunk) { + auto type = mtmd_input_chunk_get_type(chunk); + if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { + GGML_ASSERT(has_mtmd); + const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk); + size_t start_idx = tokens.size(); + for (size_t i = 0; i < n_tokens; ++i) { + tokens.emplace_back(LLAMA_TOKEN_NULL); + } + mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); + map_idx_to_media[start_idx] = std::move(new_chunk); + } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + for (size_t i = 0; i < n_tokens; ++i) { + push_back(text_tokens[i]); + } + } else { + GGML_ABORT("Invalid chunk type"); + } +} - // fill the rest of the context with extra chunks - const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); +void server_tokens::push_back(server_tokens & tokens) { + size_t start_idx = size(); + for (size_t i = 0; i < tokens.size(); i++) { + push_back(tokens[i]); + } + if (tokens.has_mtmd) { + // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd. + // We could also just check, but this will prevent silently dropping MTMD data. + GGML_ASSERT(has_mtmd); + for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) { + auto * chunk = tokens.map_idx_to_media[it->first].get(); + mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); + map_idx_to_media[start_idx + it->first] = std::move(new_chunk); + } + } +} - tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); - tokens_suffix.resize(n_suffix_take); +void server_tokens::insert(const llama_tokens & inp_tokens) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); +} - tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); - tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); - tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); +const llama_tokens & server_tokens::get_text_tokens() const { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + return tokens; +} - auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; - auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; +void server_tokens::set_token(llama_pos pos, llama_token id) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens[pos] = id; +} - if (llama_vocab_get_add_bos(vocab)) { - embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); +void server_tokens::keep_first(size_t n) { + GGML_ASSERT(n <= tokens.size()); + if (has_mtmd) { + if (n == tokens.size()) { + return; // nothing to do + } + // we throw an error if we try to remove a token in the middle of an image + // for ex. with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // n 1 2 3 4 5 6 7 8 9 10 + // allowed to resize ^ ^ + // disallowed to resize ^ ^ ^ + if (n > 0) { + // make sure we never remove tokens in the middle of an image + // note that the case where we keep a full image at the end is allowed: + // tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL + if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) { + find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk + } + } + // remove all image chunks that are not used anymore + for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) { + size_t idx = it->first; + if (idx >= n) { + it = map_idx_to_media.erase(it); + } else { + ++it; + } + } } + tokens.resize(n); +} - SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); - - // put the extra context before the FIM prefix - embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); - - embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - embd_inp.push_back(llama_vocab_fim_mid(vocab)); - - return embd_inp; +std::string server_tokens::detokenize(const llama_context * ctx, bool special) const { + llama_tokens text_tokens; + text_tokens.reserve(tokens.size()); + for (const auto & t : tokens) { + if (t != LLAMA_TOKEN_NULL) { + text_tokens.push_back(t); + } + } + return common_detokenize(ctx, text_tokens, special); } -// -// base64 utils (TODO: move to common in the future) -// +size_t server_tokens::get_common_prefix(const server_tokens & b) const { + const size_t max_idx = std::min(tokens.size(), b.tokens.size()); -static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; + if (!has_mtmd) { + for (size_t i = 0; i < max_idx; ++i) { + if (tokens[i] == b.tokens[i]) { + continue; + } -static inline bool is_base64(uint8_t c) { - return (isalnum(c) || (c == '+') || (c == '/')); -} + return i; + } -static inline raw_buffer base64_decode(const std::string & encoded_string) { - int i = 0; - int j = 0; - int in_ = 0; + return max_idx; + } - int in_len = encoded_string.size(); + for (size_t i = 0; i < max_idx; ++i) { + const llama_token ai = tokens[i]; + const llama_token bi = b.tokens[i]; - uint8_t char_array_4[4]; - uint8_t char_array_3[3]; + if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { + const auto & a_chunk = find_chunk(i); + const auto & b_chunk = b.find_chunk(i); - raw_buffer ret; + GGML_ASSERT(a_chunk && b_chunk); - while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { - char_array_4[i++] = encoded_string[in_]; in_++; - if (i == 4) { - for (i = 0; i < 4; i++) { - char_array_4[i] = base64_chars.find(char_array_4[i]); - } + const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); + const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get()); + const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get()); - for (i = 0; (i < 3); i++) { - ret.push_back(char_array_3[i]); + if (id_ai == id_bi && n_tok_a == n_tok_b) { + GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen + i += n_tok_a - 1; // will be +1 by the for loop + continue; } - i = 0; + return i; } - } - if (i) { - for (j = i; j < 4; j++) { - char_array_4[j] = 0; + if (ai == bi) { + continue; } - for (j = 0; j < 4; j++) { - char_array_4[j] = base64_chars.find(char_array_4[j]); - } + return i; + } - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + return max_idx; // all tokens are equal +} - for (j = 0; j < i - 1; j++) { - ret.push_back(char_array_3[j]); +bool server_tokens::validate(const struct llama_context * ctx) const { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); + + for (size_t i = 0; i < tokens.size(); ++i) { + const auto & t = tokens[i]; + if (t == LLAMA_TOKEN_NULL) { + try { + const auto & chunk = find_chunk(i); + size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get()); + i += n_tokens - 1; // will be +1 by the for loop + } catch (const std::exception & e) { + return false; + } + } else if (t < 0 || t >= n_vocab) { + return false; } } + return true; +} - return ret; +int32_t server_tokens::process_chunk( + llama_context * ctx, + mtmd_context * mctx, + size_t idx, + llama_pos pos, + int32_t seq_id, + size_t & n_tokens_out) const { + const auto & chunk = find_chunk(idx); + const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE + ? "image" : "audio"; + SRV_INF("processing %s...\n", name); + int32_t n_batch = llama_n_batch(ctx); + int64_t t0 = ggml_time_ms(); + llama_pos new_n_past; // unused for now + int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, + chunk.get(), + pos, + seq_id, + n_batch, + true, // logits last + &new_n_past); + SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); + if (result != 0) { + LOG_ERR("mtmd_helper_eval failed with status %d", result); + n_tokens_out = 0; + return result; + } + n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); + return 0; } // -// random string / id +// tokenizer and input processing utils // -static std::string random_string() { - static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); - - std::random_device rd; - std::mt19937 generator(rd()); - - std::string result(32, ' '); - - for (int i = 0; i < 32; ++i) { - result[i] = str[generator() % str.size()]; +bool json_is_array_of_numbers(const json & data) { + if (data.is_array()) { + for (const auto & e : data) { + if (!e.is_number_integer()) { + return false; + } + } + return true; } - - return result; + return false; } -static std::string gen_chatcmplid() { - return "chatcmpl-" + random_string(); +bool json_is_array_of_mixed_numbers_strings(const json & data) { + bool seen_string = false; + bool seen_number = false; + if (data.is_array()) { + for (const auto & e : data) { + seen_string |= e.is_string(); + seen_number |= e.is_number_integer(); + if (seen_number && seen_string) { + return true; + } + } + } + return false; } -static std::string gen_tool_call_id() { - return random_string(); +bool json_is_array_and_contains_numbers(const json & data) { + if (data.is_array()) { + for (const auto & e : data) { + if (e.is_number_integer()) { + return true; + } + } + return false; + } + return false; } -// -// other common utils -// +json json_get_nested_values(const std::vector & paths, const json & js) { + json result = json::object(); -static std::string safe_json_to_str(const json & data) { - return data.dump(-1, ' ', false, json::error_handler_t::replace); + for (const std::string & path : paths) { + json current = js; + const auto keys = string_split(path, /*separator*/ '/'); + bool valid_path = true; + for (const std::string & k : keys) { + if (valid_path && current.is_object() && current.contains(k)) { + current = current[k]; + } else { + valid_path = false; + } + } + if (valid_path) { + result[path] = current; + } + } + return result; } -// TODO: reuse llama_detokenize -template -static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { - std::string ret; - for (; begin != end; ++begin) { - ret += common_token_to_piece(ctx, *begin); - } +llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + llama_tokens prompt_tokens; - return ret; -} + if (json_prompt.is_array()) { + bool first = true; + for (const auto & p : json_prompt) { + if (p.is_string()) { + auto s = p.template get(); -// format incomplete utf-8 multibyte character for output -static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { - std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); + llama_tokens p; + if (first) { + p = common_tokenize(vocab, s, add_special, parse_special); + first = false; + } else { + p = common_tokenize(vocab, s, false, parse_special); + } - // if the size is 1 and first bit is 1, meaning it's a partial character - // (size > 1 meaning it's already a known token) - if (out.size() == 1 && (out[0] & 0x80) == 0x80) { - std::stringstream ss; - ss << std::hex << (out[0] & 0xff); - std::string res(ss.str()); - out = "byte: \\x" + res; + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); + } else { + if (first) { + first = false; + } + + prompt_tokens.push_back(p.template get()); + } + } + } else { + auto s = json_prompt.template get(); + prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); } - return out; + return prompt_tokens; } -// format server-sent event (SSE), return the formatted string to send -// note: if data is a json array, it will be sent as multiple events, one per item -static std::string format_sse(const json & data) { +std::string format_anthropic_sse(const json & data) { std::ostringstream ss; - auto send_single = [&ss](const json & data) { - ss << "data: " << - safe_json_to_str(data) << - "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). + + auto send_event = [&ss](const json & event_obj) { + if (event_obj.contains("event") && event_obj.contains("data")) { + ss << "event: " << event_obj.at("event").get() << "\n"; + ss << "data: " << safe_json_to_str(event_obj.at("data")) << "\n\n"; + } else { + ss << "data: " << safe_json_to_str(event_obj) << "\n\n"; + } }; if (data.is_array()) { - for (const auto & item : data) { - send_single(item); + for (const auto & event : data) { + send_event(event); } } else { - send_single(data); + send_event(data); } return ss.str(); } +size_t validate_utf8(const std::string& text) { + size_t len = text.size(); + if (len == 0) return 0; + + // Check the last few bytes to see if a multi-byte character is cut off + for (size_t i = 1; i <= 4 && i <= len; ++i) { + unsigned char c = text[len - i]; + // Check for start of a multi-byte sequence from the end + if ((c & 0xE0) == 0xC0) { + // 2-byte character start: 110xxxxx + // Needs at least 2 bytes + if (i < 2) return len - i; + } else if ((c & 0xF0) == 0xE0) { + // 3-byte character start: 1110xxxx + // Needs at least 3 bytes + if (i < 3) return len - i; + } else if ((c & 0xF8) == 0xF0) { + // 4-byte character start: 11110xxx + // Needs at least 4 bytes + if (i < 4) return len - i; + } + } + + // If no cut-off multi-byte character is found, return full length + return len; +} + +// Computes FNV-1a hash of the data +static std::string fnv_hash(const uint8_t * data, size_t len) { + const uint64_t fnv_prime = 0x100000001b3ULL; + uint64_t hash = 0xcbf29ce484222325ULL; + + for (size_t i = 0; i < len; ++i) { + hash ^= data[i]; + hash *= fnv_prime; + } + return std::to_string(hash); +} + +server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector files) { + mtmd::bitmaps bitmaps; + for (auto & file : files) { + mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size())); + if (!bmp.ptr) { + throw std::runtime_error("Failed to load image or audio file"); + } + // calculate bitmap hash (for KV caching) + std::string hash = fnv_hash(bmp.data(), bmp.n_bytes()); + bmp.set_id(hash.c_str()); + bitmaps.entries.push_back(std::move(bmp)); + } + // process prompt + std::vector inputs; + // multimodal + mtmd_input_text inp_txt = { + prompt.c_str(), + /* add_special */ true, + /* parse_special */ true, + }; + mtmd::input_chunks chunks(mtmd_input_chunks_init()); + auto bitmaps_c_ptr = bitmaps.c_ptr(); + int32_t tokenized = mtmd_tokenize(mctx, + chunks.ptr.get(), + &inp_txt, + bitmaps_c_ptr.data(), + bitmaps_c_ptr.size()); + if (tokenized != 0) { + throw std::runtime_error("Failed to tokenize prompt"); + } + auto result = server_tokens(chunks, true); + return result; +} + +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * use tokenize_input_prompts() if the input could be an array. + * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } + */ +static server_tokens tokenize_input_subprompt(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { + constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string"; + constexpr char JSON_MTMD_DATA_KEY[] = "multimodal_data"; + const bool has_mtmd = mctx != nullptr; + if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { + // string or mixed + llama_tokens tmp = tokenize_mixed(vocab, json_prompt, add_special, parse_special); + return server_tokens(tmp, false); + } else if (json_is_array_of_numbers(json_prompt)) { + // array of tokens + llama_tokens tmp = json_prompt.get(); + return server_tokens(tmp, false); + } else if (json_prompt.contains(JSON_STRING_PROMPT_KEY)) { + // JSON object with prompt key. + if (json_prompt.contains(JSON_MTMD_DATA_KEY)) { + if (!has_mtmd) + throw std::runtime_error("Multimodal data provided, but model does not support multimodal requests."); + + // JSON object with prompt and multimodal key. + std::vector files; + for (const auto & entry : json_prompt.at(JSON_MTMD_DATA_KEY)) { + files.push_back(base64_decode(entry)); + } + return process_mtmd_prompt(mctx, json_prompt.at(JSON_STRING_PROMPT_KEY), files); + } else { + // Not multimodal, but contains a subobject. + llama_tokens tmp = tokenize_mixed(vocab, json_prompt.at(JSON_STRING_PROMPT_KEY), add_special, parse_special); + return server_tokens(tmp, false); + } + } else { + throw std::runtime_error("\"prompt\" elements must be a string, a list of tokens, a JSON object containing a prompt string, or a list of mixed strings & tokens."); + } +} + +std::vector tokenize_input_prompts(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { + std::vector result; + if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) { + result.reserve(json_prompt.size()); + for (const auto & p : json_prompt) { + result.push_back(tokenize_input_subprompt(vocab, mctx, p,add_special, parse_special)); + } + } else { + result.push_back(tokenize_input_subprompt(vocab, mctx, json_prompt, add_special, parse_special)); + } + if (result.empty()) { + throw std::runtime_error("\"prompt\" must not be empty"); + } + return result; +} + // // OAI utils // // used by /completions endpoint -static json oaicompat_completion_params_parse(const json & body) { +json oaicompat_completion_params_parse(const json & body) { json llama_params; if (!body.contains("prompt")) { @@ -525,19 +797,8 @@ static json oaicompat_completion_params_parse(const json & body) { return llama_params; } -struct oaicompat_parser_options { - bool use_jinja; - bool prefill_assistant; - common_reasoning_format reasoning_format; - std::map chat_template_kwargs; - common_chat_templates * tmpls; - bool allow_image; - bool allow_audio; - bool enable_thinking = true; -}; - // used by /chat/completions endpoint -static json oaicompat_chat_params_parse( +json oaicompat_chat_params_parse( json & body, /* openai api json semantics */ const oaicompat_parser_options & opt, std::vector & out_files) @@ -809,7 +1070,230 @@ static json oaicompat_chat_params_parse( return llama_params; } -static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) { +json convert_anthropic_to_oai(const json & body) { + json oai_body; + + // Convert system prompt + json oai_messages = json::array(); + auto system_param = json_value(body, "system", json()); + if (!system_param.is_null()) { + std::string system_content; + + if (system_param.is_string()) { + system_content = system_param.get(); + } else if (system_param.is_array()) { + for (const auto & block : system_param) { + if (json_value(block, "type", std::string()) == "text") { + system_content += json_value(block, "text", std::string()); + } + } + } + + oai_messages.push_back({ + {"role", "system"}, + {"content", system_content} + }); + } + + // Convert messages + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + const json & messages = body.at("messages"); + if (messages.is_array()) { + for (const auto & msg : messages) { + std::string role = json_value(msg, "role", std::string()); + + if (!msg.contains("content")) { + if (role == "assistant") { + continue; + } + oai_messages.push_back(msg); + continue; + } + + const json & content = msg.at("content"); + + if (content.is_string()) { + oai_messages.push_back(msg); + continue; + } + + if (!content.is_array()) { + oai_messages.push_back(msg); + continue; + } + + json tool_calls = json::array(); + json converted_content = json::array(); + json tool_results = json::array(); + bool has_tool_calls = false; + + for (const auto & block : content) { + std::string type = json_value(block, "type", std::string()); + + if (type == "text") { + converted_content.push_back(block); + } else if (type == "image") { + json source = json_value(block, "source", json::object()); + std::string source_type = json_value(source, "type", std::string()); + + if (source_type == "base64") { + std::string media_type = json_value(source, "media_type", std::string("image/jpeg")); + std::string data = json_value(source, "data", std::string()); + + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", "data:" + media_type + ";base64," + data} + }} + }); + } else if (source_type == "url") { + std::string url = json_value(source, "url", std::string()); + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", url} + }} + }); + } + } else if (type == "tool_use") { + tool_calls.push_back({ + {"id", json_value(block, "id", std::string())}, + {"type", "function"}, + {"function", { + {"name", json_value(block, "name", std::string())}, + {"arguments", json_value(block, "input", json::object()).dump()} + }} + }); + has_tool_calls = true; + } else if (type == "tool_result") { + std::string tool_use_id = json_value(block, "tool_use_id", std::string()); + + auto result_content = json_value(block, "content", json()); + std::string result_text; + if (result_content.is_string()) { + result_text = result_content.get(); + } else if (result_content.is_array()) { + for (const auto & c : result_content) { + if (json_value(c, "type", std::string()) == "text") { + result_text += json_value(c, "text", std::string()); + } + } + } + + tool_results.push_back({ + {"role", "tool"}, + {"tool_call_id", tool_use_id}, + {"content", result_text} + }); + } + } + + if (!converted_content.empty() || has_tool_calls) { + json new_msg = {{"role", role}}; + if (!converted_content.empty()) { + new_msg["content"] = converted_content; + } else if (has_tool_calls) { + new_msg["content"] = ""; + } + if (!tool_calls.empty()) { + new_msg["tool_calls"] = tool_calls; + } + oai_messages.push_back(new_msg); + } + + for (const auto & tool_msg : tool_results) { + oai_messages.push_back(tool_msg); + } + } + } + + oai_body["messages"] = oai_messages; + + // Convert tools + if (body.contains("tools")) { + const json & tools = body.at("tools"); + if (tools.is_array()) { + json oai_tools = json::array(); + for (const auto & tool : tools) { + oai_tools.push_back({ + {"type", "function"}, + {"function", { + {"name", json_value(tool, "name", std::string())}, + {"description", json_value(tool, "description", std::string())}, + {"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()} + }} + }); + } + oai_body["tools"] = oai_tools; + } + } + + // Convert tool_choice + if (body.contains("tool_choice")) { + const json & tc = body.at("tool_choice"); + if (tc.is_object()) { + std::string type = json_value(tc, "type", std::string()); + if (type == "auto") { + oai_body["tool_choice"] = "auto"; + } else if (type == "any" || type == "tool") { + oai_body["tool_choice"] = "required"; + } + } + } + + // Convert stop_sequences to stop + if (body.contains("stop_sequences")) { + oai_body["stop"] = body.at("stop_sequences"); + } + + // Handle max_tokens (required in Anthropic, but we're permissive) + if (body.contains("max_tokens")) { + oai_body["max_tokens"] = body.at("max_tokens"); + } else { + oai_body["max_tokens"] = 4096; + } + + // Pass through common params + for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) { + if (body.contains(key)) { + oai_body[key] = body.at(key); + } + } + + // Handle Anthropic-specific thinking param + if (body.contains("thinking")) { + json thinking = json_value(body, "thinking", json::object()); + std::string thinking_type = json_value(thinking, "type", std::string()); + if (thinking_type == "enabled") { + int budget_tokens = json_value(thinking, "budget_tokens", 10000); + oai_body["thinking_budget_tokens"] = budget_tokens; + } + } + + // Handle Anthropic-specific metadata param + if (body.contains("metadata")) { + json metadata = json_value(body, "metadata", json::object()); + std::string user_id = json_value(metadata, "user_id", std::string()); + if (!user_id.empty()) { + oai_body["__metadata_user_id"] = user_id; + } + } + + return oai_body; +} + +json anthropic_params_from_json( + const json & body, + const oaicompat_parser_options & opt, + std::vector & out_files) +{ + json oai_body = convert_anthropic_to_oai(body); + return oaicompat_chat_params_parse(oai_body, opt, out_files); +} + +json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) { json data = json::array(); int32_t n_tokens = 0; int i = 0; @@ -851,7 +1335,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso return res; } -static json format_response_rerank( +json format_response_rerank( const json & request, const json & ranks, bool is_tei_format, @@ -896,63 +1380,12 @@ static json format_response_rerank( return res; } -static bool is_valid_utf8(const std::string & str) { - const unsigned char* bytes = reinterpret_cast(str.data()); - const unsigned char* end = bytes + str.length(); - - while (bytes < end) { - if (*bytes <= 0x7F) { - // 1-byte sequence (0xxxxxxx) - bytes++; - } else if ((*bytes & 0xE0) == 0xC0) { - // 2-byte sequence (110xxxxx 10xxxxxx) - if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) - return false; - bytes += 2; - } else if ((*bytes & 0xF0) == 0xE0) { - // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) - if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) - return false; - bytes += 3; - } else if ((*bytes & 0xF8) == 0xF0) { - // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || - (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) - return false; - bytes += 4; - } else { - // Invalid UTF-8 lead byte - return false; - } - } - - return true; -} - -static json format_tokenizer_response(const json & tokens) { - return json { - {"tokens", tokens} - }; -} - -static json format_detokenized_response(const std::string & content) { - return json { - {"content", content} - }; -} -static json format_logit_bias(const std::vector & logit_bias) { - json data = json::array(); - for (const auto & lb : logit_bias) { - data.push_back(json{ - {"bias", lb.bias}, - {"token", lb.token}, - }); - } - return data; -} +// +// other utils +// -static std::vector get_token_probabilities(llama_context * ctx, int idx) { +std::vector get_token_probabilities(llama_context * ctx, int idx) { std::vector cur; const auto * logits = llama_get_logits_ith(ctx, idx); @@ -986,538 +1419,203 @@ static std::vector get_token_probabilities(llama_context * ctx return cur; } -static bool are_lora_equal( - const std::vector & l1, - const std::vector & l2) { - if (l1.size() != l2.size()) { - return false; - } - for (size_t i = 0; i < l1.size(); ++i) { - // we don't check lora.path to reduce the time complexity - if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { - return false; - } - } - return true; +std::string safe_json_to_str(const json & data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); } -// get the ids of all enabled loras -static std::vector lora_get_enabled_ids(const std::vector & loras) { - std::vector enabled_ids; - for (size_t i = 0; i < loras.size(); ++i) { - if (loras[i].scale > 0) { - enabled_ids.push_back(i); - } +// TODO: reuse llama_detokenize +template +static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { + std::string ret; + for (; begin != end; ++begin) { + ret += common_token_to_piece(ctx, *begin); } - return enabled_ids; + + return ret; } -// check whether the given lora set has only aloras activated (empty => false) -static bool lora_all_alora(const std::vector & loras) { - bool found_alora = false; - for (const auto & lora : loras) { - if (lora.scale != 0) { - if (llama_adapter_get_alora_n_invocation_tokens(lora.ptr) == 0) { - return false; - } - found_alora = true; - } - } - return found_alora; +std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens) { + return tokens_to_str(ctx, tokens.begin(), tokens.end()); } -// if the two sets of loras are different, they require a cache clear unless the -// change is only from aloras to aloras. -static bool lora_should_clear_cache( - const std::vector & current, - const std::vector & next) { +// format incomplete utf-8 multibyte character for output +std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { + std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); - // This should always be called after determining that the two sets are - // _not_ equal. This assert is therefore some slightly wasted work and - // should be safe to remove as long as this method is called correctly. - GGML_ASSERT(!are_lora_equal(current, next)); + // if the size is 1 and first bit is 1, meaning it's a partial character + // (size > 1 meaning it's already a known token) + if (out.size() == 1 && (out[0] & 0x80) == 0x80) { + std::stringstream ss; + ss << std::hex << (out[0] & 0xff); + std::string res(ss.str()); + out = "byte: \\x" + res; + } - return ( - !(lora_get_enabled_ids(current).empty() || lora_all_alora(current)) || - !lora_all_alora(next)); + return out; } -// parse lora config from JSON request, returned a copy of lora_base with updated scale -static std::vector parse_lora_request( - const std::vector & lora_base, - const json & data) { - std::vector lora(lora_base); - int max_idx = lora.size(); - - // clear existing value - for (auto & entry : lora) { - entry.scale = 0.0f; - } +// format server-sent event (SSE), return the formatted string to send +// note: if data is a json array, it will be sent as multiple events, one per item +std::string format_sse(const json & data) { + std::ostringstream ss; + auto send_single = [&ss](const json & data) { + ss << "data: " << + safe_json_to_str(data) << + "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). + }; - // set value - for (const auto & entry : data) { - int id = json_value(entry, "id", -1); - float scale = json_value(entry, "scale", 0.0f); - if (0 <= id && id < max_idx) { - lora[id].scale = scale; - } else { - throw std::runtime_error("invalid adapter id"); + if (data.is_array()) { + for (const auto & item : data) { + send_single(item); } + } else { + send_single(data); } - return lora; + return ss.str(); } -// -// utils for interacting with libmtmd -// (may need to refactor in near future) -// - -/** - * server_tokens is a helper to manage the input tokens and image for the server. - * it is made this way to simplify the logic of KV cache management. - */ -struct server_tokens { - bool has_mtmd = false; - -private: // disallow accessing these members directly, risking out-of-sync - - // map a **start** index in tokens to the image chunk - // note: the order need to be in-sync with tokens - std::map map_idx_to_media; - - // list of tokens - // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk - // otherwise, it is a normal text token - // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list - // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos - llama_tokens tokens; - - // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos): - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1] - // idx 0 1 2 3 4 5 6 7 8 9 10 - // pos 0 1 2 3 4 5 5 5 7 7 7 - // map_idx_to_media will contain: {5, img0}, {8, img1} - -public: - server_tokens() = default; - ~server_tokens() = default; - - // Prevent copying - // TODO: server_tokens should be copyable - remove this: - server_tokens(const server_tokens&) = delete; - server_tokens& operator=(const server_tokens&) = delete; - - // Allow moving (usually implicitly generated if members are movable) - server_tokens(server_tokens&&) = default; - server_tokens& operator=(server_tokens&&) = default; - - // Allow accessing elements using [] operator - llama_token operator[](size_t index) { return tokens[index]; } - const llama_token& operator[](size_t index) const { return tokens[index]; } +bool is_valid_utf8(const std::string & str) { + const unsigned char* bytes = reinterpret_cast(str.data()); + const unsigned char* end = bytes + str.length(); - server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { - for (size_t i = 0; i < mtmd_chunks.size(); ++i) { - push_back(mtmd_chunks[i]); + while (bytes < end) { + if (*bytes <= 0x7F) { + // 1-byte sequence (0xxxxxxx) + bytes++; + } else if ((*bytes & 0xE0) == 0xC0) { + // 2-byte sequence (110xxxxx 10xxxxxx) + if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) + return false; + bytes += 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) + return false; + bytes += 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || + (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) + return false; + bytes += 4; + } else { + // Invalid UTF-8 lead byte + return false; } } - server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { - } - - llama_pos pos_next() const { - if (!has_mtmd) { - return tokens.size(); - } + return true; +} - llama_pos res = tokens.size(); +llama_tokens format_prompt_infill( + const llama_vocab * vocab, + const json & input_prefix, + const json & input_suffix, + const json & input_extra, + const int n_batch, + const int n_predict, + const int n_ctx, + const bool spm_infill, + const llama_tokens & tokens_prompt + ) { + // TODO: optimize this block by reducing memory allocations and movement - for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { - const auto & chunk = it->second; - res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get()); - } + // use FIM repo-level pattern: + // ref: https://arxiv.org/pdf/2409.12186 + // + // [FIM_REP]myproject + // [FIM_SEP]filename0 + // extra chunk 0 + // [FIM_SEP]filename1 + // extra chunk 1 + // ... + // [FIM_SEP]filename + // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt + // + llama_tokens extra_tokens; + extra_tokens.reserve(n_ctx); - return res; - } + auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); + auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); - // for debugging - std::string str() const { - std::ostringstream oss; - oss << "tokens: "; - for (size_t idx = 0; idx < tokens.size(); ++idx) { - llama_token t = tokens[idx]; - oss << "idx:" << idx << " "; - if (t == LLAMA_TOKEN_NULL) { - oss << " "; - } else { - oss << t << " "; - } - } - oss << "\n"; - oss << "image idx: "; - for (const auto & it : map_idx_to_media) { - oss << it.first << ", "; - } - return oss.str(); - } + if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: make project name an input + static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); - const mtmd::input_chunk_ptr & find_chunk(size_t idx) const { - auto it = map_idx_to_media.find(idx); - if (it != map_idx_to_media.end()) { - return it->second; - } - throw std::runtime_error("Chunk not found"); + extra_tokens.push_back(llama_vocab_fim_rep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); } + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + const std::string text = json_value(chunk, "text", std::string()); + const std::string filename = json_value(chunk, "filename", std::string("tmp")); - void push_back(llama_token tok) { - if (tok == LLAMA_TOKEN_NULL) { - throw std::runtime_error("Invalid token"); - } - tokens.emplace_back(tok); - } + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); - // will create a copy of the chunk if it contains non-text data - void push_back(const mtmd_input_chunk * chunk) { - auto type = mtmd_input_chunk_get_type(chunk); - if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { - GGML_ASSERT(has_mtmd); - const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk); - size_t start_idx = tokens.size(); - for (size_t i = 0; i < n_tokens; ++i) { - tokens.emplace_back(LLAMA_TOKEN_NULL); - } - mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_idx_to_media[start_idx] = std::move(new_chunk); - } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - size_t n_tokens; - const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); - for (size_t i = 0; i < n_tokens; ++i) { - push_back(text_tokens[i]); - } + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); } else { - GGML_ABORT("Invalid chunk type"); - } - } + // chunk separator in binary form to avoid confusing the AI + static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; + static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); - // appends server tokens, updates the media map. copies media chunks. - void push_back(server_tokens & tokens) { - size_t start_idx = size(); - for (size_t i = 0; i < tokens.size(); i++) { - push_back(tokens[i]); - } - if (tokens.has_mtmd) { - // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd. - // We could also just check, but this will prevent silently dropping MTMD data. - GGML_ASSERT(has_mtmd); - for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) { - auto * chunk = tokens.map_idx_to_media[it->first].get(); - mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_idx_to_media[start_idx + it->first] = std::move(new_chunk); - } + extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); } - } - // for compatibility with context shift and prompt truncation - void insert(const llama_tokens & inp_tokens) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); - } - - // for compatibility with speculative decoding, ctx shift, slot save/load - const llama_tokens & get_text_tokens() const { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - return tokens; - } - - // for compatibility with speculative decoding - void set_token(llama_pos pos, llama_token id) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens[pos] = id; - } - - size_t size() const { - return tokens.size(); - } - - bool empty() const { - return tokens.empty(); - } - - void clear() { - map_idx_to_media.clear(); - tokens.clear(); + const auto chunk_tokens = common_tokenize(vocab, text, false, false); + extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); } - void keep_first(size_t n) { - GGML_ASSERT(n <= tokens.size()); - if (has_mtmd) { - if (n == tokens.size()) { - return; // nothing to do - } - // we throw an error if we try to remove a token in the middle of an image - // for ex. with input of 5 text tokens and 2 images: - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] - // n 1 2 3 4 5 6 7 8 9 10 - // allowed to resize ^ ^ - // disallowed to resize ^ ^ ^ - if (n > 0) { - // make sure we never remove tokens in the middle of an image - // note that the case where we keep a full image at the end is allowed: - // tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL - if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) { - find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk - } - } - // remove all image chunks that are not used anymore - for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) { - size_t idx = it->first; - if (idx >= n) { - it = map_idx_to_media.erase(it); - } else { - ++it; - } - } - } - tokens.resize(n); - } + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: current filename + static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); - std::string detokenize(const llama_context * ctx, bool special) const { - llama_tokens text_tokens; - text_tokens.reserve(tokens.size()); - for (const auto & t : tokens) { - if (t != LLAMA_TOKEN_NULL) { - text_tokens.push_back(t); - } - } - return common_detokenize(ctx, text_tokens, special); + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); } - size_t get_common_prefix(const server_tokens & b) const { - const size_t max_idx = std::min(tokens.size(), b.tokens.size()); - - if (!has_mtmd) { - for (size_t i = 0; i < max_idx; ++i) { - if (tokens[i] == b.tokens[i]) { - continue; - } - - return i; - } - - return max_idx; - } - - for (size_t i = 0; i < max_idx; ++i) { - const llama_token ai = tokens[i]; - const llama_token bi = b.tokens[i]; - - if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { - const auto & a_chunk = find_chunk(i); - const auto & b_chunk = b.find_chunk(i); - - GGML_ASSERT(a_chunk && b_chunk); - - const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); - const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); - - const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get()); - const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get()); - - if (id_ai == id_bi && n_tok_a == n_tok_b) { - GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen - i += n_tok_a - 1; // will be +1 by the for loop - continue; - } - - return i; - } - - if (ai == bi) { - continue; - } - - return i; - } + // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) + const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); + const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); - return max_idx; // all tokens are equal - } + SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); - // make sure all text tokens are within the vocab range - bool validate(const struct llama_context * ctx) const { - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - const int32_t n_vocab = llama_vocab_n_tokens(vocab); + // fill the rest of the context with extra chunks + const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); - for (size_t i = 0; i < tokens.size(); ++i) { - const auto & t = tokens[i]; - if (t == LLAMA_TOKEN_NULL) { - try { - const auto & chunk = find_chunk(i); - size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get()); - i += n_tokens - 1; // will be +1 by the for loop - } catch (const std::exception & e) { - return false; - } - } else if (t < 0 || t >= n_vocab) { - return false; - } - } - return true; - } + tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); + tokens_suffix.resize(n_suffix_take); - // encode and decode the image chunk - int32_t process_chunk( - llama_context * ctx, - mtmd_context * mctx, - size_t idx, - llama_pos pos, - int32_t seq_id, - size_t & n_tokens_out) const { - const auto & chunk = find_chunk(idx); - const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE - ? "image" : "audio"; - SRV_INF("processing %s...\n", name); - int32_t n_batch = llama_n_batch(ctx); - int64_t t0 = ggml_time_ms(); - llama_pos new_n_past; // unused for now - int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, - chunk.get(), - pos, - seq_id, - n_batch, - true, // logits last - &new_n_past); - SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); - if (result != 0) { - LOG_ERR("mtmd_helper_eval failed with status %d", result); - n_tokens_out = 0; - return result; - } - n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); - return 0; - } -}; + tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); + tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); + tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); -// Computes FNV-1a hash of the data -static std::string fnv_hash(const uint8_t * data, size_t len) { - const uint64_t fnv_prime = 0x100000001b3ULL; - uint64_t hash = 0xcbf29ce484222325ULL; + auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; + auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; - for (size_t i = 0; i < len; ++i) { - hash ^= data[i]; - hash *= fnv_prime; + if (llama_vocab_get_add_bos(vocab)) { + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); } - return std::to_string(hash); -} -static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector files) { - mtmd::bitmaps bitmaps; - for (auto & file : files) { - mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size())); - if (!bmp.ptr) { - throw std::runtime_error("Failed to load image or audio file"); - } - // calculate bitmap hash (for KV caching) - std::string hash = fnv_hash(bmp.data(), bmp.n_bytes()); - bmp.set_id(hash.c_str()); - bitmaps.entries.push_back(std::move(bmp)); - } - // process prompt - std::vector inputs; - // multimodal - mtmd_input_text inp_txt = { - prompt.c_str(), - /* add_special */ true, - /* parse_special */ true, - }; - mtmd::input_chunks chunks(mtmd_input_chunks_init()); - auto bitmaps_c_ptr = bitmaps.c_ptr(); - int32_t tokenized = mtmd_tokenize(mctx, - chunks.ptr.get(), - &inp_txt, - bitmaps_c_ptr.data(), - bitmaps_c_ptr.size()); - if (tokenized != 0) { - throw std::runtime_error("Failed to tokenize prompt"); - } - auto result = server_tokens(chunks, true); - return result; -} + SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); -/** - * break the input "prompt" object into multiple prompt if needed, then tokenize them - * use tokenize_input_prompts() if the input could be an array. - * this supports these cases: - * - "prompt": "string" - * - "prompt": [12, 34, 56] - * - "prompt": [12, 34, "string", 56, 78] - * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } - */ -static server_tokens tokenize_input_subprompt(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { - constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string"; - constexpr char JSON_MTMD_DATA_KEY[] = "multimodal_data"; - const bool has_mtmd = mctx != nullptr; - if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { - // string or mixed - llama_tokens tmp = tokenize_mixed(vocab, json_prompt, add_special, parse_special); - return server_tokens(tmp, false); - } else if (json_is_array_of_numbers(json_prompt)) { - // array of tokens - llama_tokens tmp = json_prompt.get(); - return server_tokens(tmp, false); - } else if (json_prompt.contains(JSON_STRING_PROMPT_KEY)) { - // JSON object with prompt key. - if (json_prompt.contains(JSON_MTMD_DATA_KEY)) { - if (!has_mtmd) - throw std::runtime_error("Multimodal data provided, but model does not support multimodal requests."); + // put the extra context before the FIM prefix + embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); - // JSON object with prompt and multimodal key. - std::vector files; - for (const auto & entry : json_prompt.at(JSON_MTMD_DATA_KEY)) { - files.push_back(base64_decode(entry)); - } - return process_mtmd_prompt(mctx, json_prompt.at(JSON_STRING_PROMPT_KEY), files); - } else { - // Not multimodal, but contains a subobject. - llama_tokens tmp = tokenize_mixed(vocab, json_prompt.at(JSON_STRING_PROMPT_KEY), add_special, parse_special); - return server_tokens(tmp, false); - } - } else { - throw std::runtime_error("\"prompt\" elements must be a string, a list of tokens, a JSON object containing a prompt string, or a list of mixed strings & tokens."); - } -} + embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); + embd_inp.push_back(llama_vocab_fim_mid(vocab)); -/** - * break the input "prompt" object into multiple prompt if needed, then tokenize them - * this supports these cases: - * - "prompt": "string" - * - "prompt": [12, 34, 56] - * - "prompt": [12, 34, "string", 56, 78] - * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } - * and multiple prompts (multi-tasks): - * - "prompt": ["string1", "string2"] - * - "prompt": ["string1", [12, 34, 56]] - * - "prompt": [[12, 34, 56], [78, 90, 12]] - * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}] - */ -static std::vector tokenize_input_prompts(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { - std::vector result; - if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) { - result.reserve(json_prompt.size()); - for (const auto & p : json_prompt) { - result.push_back(tokenize_input_subprompt(vocab, mctx, p,add_special, parse_special)); - } - } else { - result.push_back(tokenize_input_subprompt(vocab, mctx, json_prompt, add_special, parse_special)); - } - if (result.empty()) { - throw std::runtime_error("\"prompt\" must not be empty"); - } - return result; + return embd_inp; } -// format rerank task: [BOS]query[EOS][SEP]doc[EOS]. -static server_tokens format_rerank(const struct llama_model * model, const struct llama_vocab * vocab, mtmd_context * mctx, const std::string & query, const std::string & doc) { +server_tokens format_prompt_rerank( + const struct llama_model * model, + const struct llama_vocab * vocab, + mtmd_context * mctx, + const std::string & query, + const std::string & doc) { server_tokens result = {}; const char * rerank_prompt = llama_model_chat_template(model, "rerank"); diff --git a/tools/server/server-common.h b/tools/server/server-common.h new file mode 100644 index 00000000000..c759b7145d0 --- /dev/null +++ b/tools/server/server-common.h @@ -0,0 +1,361 @@ +#pragma once + +#include "common.h" +#include "log.h" +#include "llama.h" +#include "chat.h" +#include "mtmd.h" + +#define JSON_ASSERT GGML_ASSERT +#include + +#include +#include +#include + +#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" + +const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); + +using json = nlohmann::ordered_json; + +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) + +#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +using raw_buffer = std::vector; + +template +static T json_value(const json & body, const std::string & key, const T & default_value) { + // Fallback null to default value + if (body.contains(key) && !body.at(key).is_null()) { + try { + return body.at(key); + } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) { + LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what()); + return default_value; + } + } else { + return default_value; + } +} + +// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 +enum error_type { + ERROR_TYPE_INVALID_REQUEST, + ERROR_TYPE_AUTHENTICATION, + ERROR_TYPE_SERVER, + ERROR_TYPE_NOT_FOUND, + ERROR_TYPE_PERMISSION, + ERROR_TYPE_UNAVAILABLE, // custom error + ERROR_TYPE_NOT_SUPPORTED, // custom error + ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error +}; + +// thin wrapper around common_grammar_trigger with (de)serialization functions +struct server_grammar_trigger { + common_grammar_trigger value; + + server_grammar_trigger() = default; + server_grammar_trigger(const common_grammar_trigger & value) : value(value) {} + server_grammar_trigger(const json & in) { + value.type = (common_grammar_trigger_type) in.at("type").get(); + value.value = in.at("value").get(); + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + value.token = (llama_token) in.at("token").get(); + } + } + + json to_json() const { + json out { + {"type", (int) value.type}, + {"value", value.value}, + }; + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + out["token"] = (int) value.token; + } + return out; + } +}; + +json format_error_response(const std::string & message, const enum error_type type); + +// +// random string / id +// + +std::string random_string(); +std::string gen_chatcmplid(); +std::string gen_tool_call_id(); + +// +// lora utils +// + +// check whether the given lora set has only aloras activated (empty => false) +bool lora_all_alora(const std::vector & loras); + +// if the two sets of loras are different, they require a cache clear unless the +// change is only from aloras to aloras. +bool lora_should_clear_cache( + const std::vector & current, + const std::vector & next); + +std::vector parse_lora_request( + const std::vector & lora_base, + const json & data); + +bool are_lora_equal( + const std::vector & l1, + const std::vector & l2); + +// get the ids of all enabled loras +std::vector lora_get_enabled_ids(const std::vector & loras); + +// +// server_tokens +// + +/** + * server_tokens is a helper to manage the input tokens and image for the server. + * it is made this way to simplify the logic of KV cache management. + */ +struct server_tokens { + bool has_mtmd = false; + +private: // disallow accessing these members directly, risking out-of-sync + + // map a **start** index in tokens to the image chunk + // note: the order need to be in-sync with tokens + std::map map_idx_to_media; + + // list of tokens + // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk + // otherwise, it is a normal text token + // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list + // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos + llama_tokens tokens; + + // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos): + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1] + // idx 0 1 2 3 4 5 6 7 8 9 10 + // pos 0 1 2 3 4 5 5 5 7 7 7 + // map_idx_to_media will contain: {5, img0}, {8, img1} + +public: + server_tokens() = default; + ~server_tokens() = default; + + // Prevent copying + // TODO: server_tokens should be copyable - remove this: + server_tokens(const server_tokens&) = delete; + server_tokens& operator=(const server_tokens&) = delete; + + // Allow moving (usually implicitly generated if members are movable) + server_tokens(server_tokens&&) = default; + server_tokens& operator=(server_tokens&&) = default; + + // Allow accessing elements using [] operator + llama_token operator[](size_t index) { return tokens[index]; } + const llama_token& operator[](size_t index) const { return tokens[index]; } + + server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd); + server_tokens(const llama_tokens & tokens, bool has_mtmd); + + // for debugging + std::string str() const; + + llama_pos pos_next() const; + const mtmd::input_chunk_ptr & find_chunk(size_t idx) const; + + void push_back(llama_token tok); + + // will create a copy of the chunk if it contains non-text data + void push_back(const mtmd_input_chunk * chunk); + + // appends server tokens, updates the media map. copies media chunks. + void push_back(server_tokens & tokens); + + // for compatibility with context shift and prompt truncation + void insert(const llama_tokens & inp_tokens); + + // for compatibility with speculative decoding, ctx shift, slot save/load + const llama_tokens & get_text_tokens() const; + + // for compatibility with speculative decoding + void set_token(llama_pos pos, llama_token id); + + size_t size() const { return tokens.size(); } + + bool empty() const { return tokens.empty(); } + + void clear() { + map_idx_to_media.clear(); + tokens.clear(); + } + + void keep_first(size_t n); + + std::string detokenize(const llama_context * ctx, bool special) const; + + size_t get_common_prefix(const server_tokens & b) const; + + // make sure all text tokens are within the vocab range + bool validate(const struct llama_context * ctx) const; + + // encode and decode the image chunk + int32_t process_chunk( + llama_context * ctx, + mtmd_context * mctx, + size_t idx, + llama_pos pos, + int32_t seq_id, + size_t & n_tokens_out) const; +}; + + +// +// tokenizer and input processing utils +// + +bool json_is_array_of_numbers(const json & data); + +// is array having BOTH numbers & strings? +bool json_is_array_of_mixed_numbers_strings(const json & data); + +// does array have any individual integers/tokens? +bool json_is_array_and_contains_numbers(const json & data); + +// get value by path(key1 / key2) +json json_get_nested_values(const std::vector & paths, const json & js); + +/** + * this handles 2 cases: + * - only string, example: "string" + * - mixed string and tokens, example: [12, 34, "string", 56, 78] + */ +llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special); + +// return the last index of character that can form a valid string +// if the last character is potentially cut in half, return the index before the cut +// if validate_utf8(text) == text.size(), then the whole text is valid utf8 +size_t validate_utf8(const std::string& text); + +// process mtmd prompt, return the server_tokens containing both text tokens and media chunks +server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector files); + +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } + * and multiple prompts (multi-tasks): + * - "prompt": ["string1", "string2"] + * - "prompt": ["string1", [12, 34, 56]] + * - "prompt": [[12, 34, 56], [78, 90, 12]] + * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}] + */ +std::vector tokenize_input_prompts( + const llama_vocab * vocab, + mtmd_context * mctx, + const json & json_prompt, + bool add_special, + bool parse_special); + +// +// OAI utils +// + +// used by /completions endpoint +json oaicompat_completion_params_parse(const json & body); + +struct oaicompat_parser_options { + bool use_jinja; + bool prefill_assistant; + common_reasoning_format reasoning_format; + std::map chat_template_kwargs; + common_chat_templates * tmpls; + bool allow_image; + bool allow_audio; + bool enable_thinking = true; +}; + +// used by /chat/completions endpoint +json oaicompat_chat_params_parse( + json & body, /* openai api json semantics */ + const oaicompat_parser_options & opt, + std::vector & out_files); + +// convert Anthropic Messages API format to OpenAI Chat Completions API format +json convert_anthropic_to_oai(const json & body); + +// used by Anthropic /v1/messages endpoint +json anthropic_params_from_json( + const json & body, /* anthropic messages api json semantics */ + const oaicompat_parser_options & opt, + std::vector & out_files); + +// TODO: move it to server-task.cpp +json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false); + +// TODO: move it to server-task.cpp +json format_response_rerank( + const json & request, + const json & ranks, + bool is_tei_format, + std::vector & texts, + int top_n); + +// +// other utils +// + +std::vector get_token_probabilities(llama_context * ctx, int idx); + +std::string safe_json_to_str(const json & data); + +std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens); + +// format incomplete utf-8 multibyte character for output +std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token); + +// format server-sent event (SSE), return the formatted string to send +// note: if data is a json array, it will be sent as multiple events, one per item +std::string format_sse(const json & data); + +// format Anthropic-style SSE with event types +std::string format_anthropic_sse(const json & data); + +bool is_valid_utf8(const std::string & str); + +// +// formatting output responses +// TODO: move these to server-task.cpp +// + +llama_tokens format_prompt_infill( + const llama_vocab * vocab, + const json & input_prefix, + const json & input_suffix, + const json & input_extra, + const int n_batch, + const int n_predict, + const int n_ctx, + const bool spm_infill, + const llama_tokens & tokens_prompt); + +// format rerank task: [BOS]query[EOS][SEP]doc[EOS]. +server_tokens format_prompt_rerank( + const struct llama_model * model, + const struct llama_vocab * vocab, + mtmd_context * mctx, + const std::string & query, + const std::string & doc); diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index bebe0b49c4f..fe532090100 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -1,6 +1,6 @@ -#include "utils.hpp" #include "common.h" #include "server-http.h" +#include "server-common.h" #include @@ -136,7 +136,7 @@ bool server_http_context::init(const common_params & params) { return true; } - // Check for API key in the header + // Check for API key in the Authorization header auto auth_header = req.get_header_value("Authorization"); std::string prefix = "Bearer "; @@ -147,6 +147,13 @@ bool server_http_context::init(const common_params & params) { } } + // Check for API key in the x-api-key header + auto x_api_key_header = req.get_header_value("X-Api-Key"); + + if (std::find(api_keys.begin(), api_keys.end(), x_api_key_header) != api_keys.end()) { + return true; // API key is valid + } + // API key is invalid or not provided res.status = 401; res.set_content( diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp new file mode 100644 index 00000000000..5a74fd76ac3 --- /dev/null +++ b/tools/server/server-queue.cpp @@ -0,0 +1,268 @@ +#include "server-task.h" +#include "server-queue.h" + +#include "log.h" + +#include + +#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +// +// server_queue +// + +int server_queue::post(server_task && task, bool front) { + std::unique_lock lock(mutex_tasks); + GGML_ASSERT(task.id != -1); + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + const int task_id = task.id; + QUE_DBG("new task, id = %d, front = %d\n", task_id, front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + condition_tasks.notify_one(); + return task_id; +} + +int server_queue::post(std::vector && tasks, bool front) { + std::unique_lock lock(mutex_tasks); + for (auto & task : tasks) { + if (task.id == -1) { + task.id = id++; + } + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + } + condition_tasks.notify_one(); + return 0; +} + +void server_queue::defer(server_task && task) { + std::unique_lock lock(mutex_tasks); + QUE_DBG("defer task, id = %d\n", task.id); + queue_tasks_deferred.push_back(std::move(task)); + condition_tasks.notify_one(); +} + +int server_queue::get_new_id() { + std::unique_lock lock(mutex_tasks); + int new_id = id++; + return new_id; +} + +void server_queue::on_new_task(std::function callback) { + callback_new_task = std::move(callback); +} + +void server_queue::on_update_slots(std::function callback) { + callback_update_slots = std::move(callback); +} + +void server_queue::pop_deferred_task() { + std::unique_lock lock(mutex_tasks); + if (!queue_tasks_deferred.empty()) { + queue_tasks.emplace_front(std::move(queue_tasks_deferred.front())); + queue_tasks_deferred.pop_front(); + } + condition_tasks.notify_one(); +} + +void server_queue::terminate() { + std::unique_lock lock(mutex_tasks); + running = false; + condition_tasks.notify_all(); +} + +void server_queue::start_loop() { + running = true; + + while (true) { + QUE_DBG("%s", "processing new tasks\n"); + + while (true) { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + lock.unlock(); + break; + } + server_task task = std::move(queue_tasks.front()); + queue_tasks.pop_front(); + lock.unlock(); + + QUE_DBG("processing task, id = %d\n", task.id); + callback_new_task(std::move(task)); + } + + // all tasks in the current loop is processed, slots data is now ready + QUE_DBG("%s", "update slots\n"); + + callback_update_slots(); + + QUE_DBG("%s", "waiting for new tasks\n"); + { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + condition_tasks.wait(lock, [&]{ + return (!queue_tasks.empty() || !running); + }); + } + } + } +} + +void server_queue::cleanup_pending_task(int id_target) { + // no need lock because this is called exclusively by post() + auto rm_func = [id_target](const server_task & task) { + return task.id == id_target; + }; + queue_tasks.erase( + std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), + queue_tasks.end()); + queue_tasks_deferred.erase( + std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), + queue_tasks_deferred.end()); +} + +// +// server_response +// + +void server_response::add_waiting_task_id(int id_task) { + RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.insert(id_task); +} + +void server_response::add_waiting_tasks(const std::vector & tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & task : tasks) { + RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); + waiting_task_ids.insert(task.id); + } +} + +void server_response::remove_waiting_task_id(int id_task) { + RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.erase(id_task); + // make sure to clean up all pending results + queue_results.erase( + std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { + return res->id == id_task; + }), + queue_results.end()); +} + +void server_response::remove_waiting_task_ids(const std::unordered_set & id_tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & id_task : id_tasks) { + RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + waiting_task_ids.erase(id_task); + } +} + +server_task_result_ptr server_response::recv(const std::unordered_set & id_tasks) { + while (true) { + std::unique_lock lock(mutex_results); + condition_results.wait(lock, [&]{ + if (!running) { + RES_DBG("%s : queue result stop\n", __func__); + std::terminate(); // we cannot return here since the caller is HTTP code + } + return !queue_results.empty(); + }); + + for (size_t i = 0; i < queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + + // should never reach here +} + +server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { + while (true) { + std::unique_lock lock(mutex_results); + + for (int i = 0; i < (int) queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + + std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); + if (!running) { + RES_DBG("%s : queue result stop\n", __func__); + std::terminate(); // we cannot return here since the caller is HTTP code + } + if (cr_res == std::cv_status::timeout) { + return nullptr; + } + } + + // should never reach here +} + +server_task_result_ptr server_response::recv(int id_task) { + std::unordered_set id_tasks = {id_task}; + return recv(id_tasks); +} + +void server_response::send(server_task_result_ptr && result) { + RES_DBG("sending result for task id = %d\n", result->id); + + std::unique_lock lock(mutex_results); + for (const auto & id_task : waiting_task_ids) { + if (result->id == id_task) { + RES_DBG("task id = %d pushed to result queue\n", result->id); + + queue_results.emplace_back(std::move(result)); + condition_results.notify_all(); + return; + } + } +} + +void server_response::terminate() { + running = false; + condition_results.notify_all(); +} diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h new file mode 100644 index 00000000000..47ef58425ea --- /dev/null +++ b/tools/server/server-queue.h @@ -0,0 +1,110 @@ +#pragma once + +#include "server-task.h" + +#include +#include +#include +#include + +struct server_queue { +private: + int id = 0; + bool running; + + // queues + std::deque queue_tasks; + std::deque queue_tasks_deferred; + + std::mutex mutex_tasks; + std::condition_variable condition_tasks; + + // callback functions + std::function callback_new_task; + std::function callback_update_slots; + +public: + // Add a new task to the end of the queue + int post(server_task && task, bool front = false); + + // multi-task version of post() + int post(std::vector && tasks, bool front = false); + + // Add a new task, but defer until one slot is available + void defer(server_task && task); + + // Get the next id for creating a new task + int get_new_id(); + + // Register function to process a new task + void on_new_task(std::function callback); + + // Register the function to be called when all slots data is ready to be processed + void on_update_slots(std::function callback); + + // Call when the state of one slot is changed, it will move one task from deferred to main queue + void pop_deferred_task(); + + // end the start_loop routine + void terminate(); + + /** + * Main loop consists of these steps: + * - Wait until a new task arrives + * - Process the task (i.e. maybe copy data into slot) + * - Check if multitask is finished + * - Update all slots + */ + void start_loop(); + + // for metrics + size_t queue_tasks_deferred_size() { + std::unique_lock lock(mutex_tasks); + return queue_tasks_deferred.size(); + } + +private: + void cleanup_pending_task(int id_target); +}; + +struct server_response { +private: + bool running = true; + + // for keeping track of all tasks waiting for the result + std::unordered_set waiting_task_ids; + + // the main result queue (using ptr for polymorphism) + std::vector queue_results; + + std::mutex mutex_results; + std::condition_variable condition_results; + +public: + // add the id_task to the list of tasks waiting for response + void add_waiting_task_id(int id_task); + + void add_waiting_tasks(const std::vector & tasks); + + // when the request is finished, we can remove task associated with it + void remove_waiting_task_id(int id_task); + + // remove multiple tasks from waiting list + void remove_waiting_task_ids(const std::unordered_set & id_tasks); + + // This function blocks the thread until there is a response for one of the id_tasks + server_task_result_ptr recv(const std::unordered_set & id_tasks); + + // same as recv(), but have timeout in seconds + // if timeout is reached, nullptr is returned + server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout); + + // single-task version of recv() + server_task_result_ptr recv(int id_task); + + // Send a new result to a waiting id_task + void send(server_task_result_ptr && result); + + // terminate the waiting loop + void terminate(); +}; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp new file mode 100644 index 00000000000..b18239211f6 --- /dev/null +++ b/tools/server/server-task.cpp @@ -0,0 +1,1474 @@ +#include "server-common.h" +#include "server-task.h" + +#include "common.h" +#include "llama.h" +#include "chat.h" +#include "sampling.h" +#include "json-schema-to-grammar.h" + +using json = nlohmann::ordered_json; + +// +// task_params +// + +json task_params::format_logit_bias(const std::vector & logit_bias) const { + json data = json::array(); + for (const auto & lb : logit_bias) { + data.push_back(json{ + {"bias", lb.bias}, + {"token", lb.token}, + }); + } + return data; +} + +json task_params::to_json(bool only_metrics) const { + std::vector samplers; + samplers.reserve(sampling.samplers.size()); + for (const auto & sampler : sampling.samplers) { + samplers.emplace_back(common_sampler_type_to_str(sampler)); + } + + json lora = json::array(); + for (size_t i = 0; i < this->lora.size(); ++i) { + lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + } + + if (only_metrics) { + return json { + {"seed", sampling.seed}, + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"top_n_sigma", sampling.top_n_sigma}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? + {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, + {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, + {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, + {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, + {"samplers", samplers}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, + }; + } + + auto grammar_triggers = json::array(); + for (const auto & trigger : sampling.grammar_triggers) { + server_grammar_trigger ct(trigger); + grammar_triggers.push_back(ct.to_json()); + } + + return json { + {"seed", sampling.seed}, + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"top_n_sigma", sampling.top_n_sigma}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"dry_sequence_breakers", sampling.dry_sequence_breakers}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"stop", antiprompt}, + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? + {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + {"logit_bias", format_logit_bias(sampling.logit_bias)}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"grammar", sampling.grammar}, + {"grammar_lazy", sampling.grammar_lazy}, + {"grammar_triggers", grammar_triggers}, + {"preserved_tokens", sampling.preserved_tokens}, + {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, + {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, + {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, + {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, + {"samplers", samplers}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, + }; +} + +// +// server_task +// + +task_params server_task::params_from_json_cmpl( + const llama_context * ctx, + const common_params & params_base, + const json & data) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + task_params params; + + // Sampling parameter defaults are loaded from the global server context (but individual requests can still them) + task_params defaults; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + defaults.n_keep = params_base.n_keep; + defaults.n_predict = params_base.n_predict; + defaults.antiprompt = params_base.antiprompt; + + // enabling this will output extra debug information in the HTTP responses from the server + params.verbose = params_base.verbosity > 9; + params.timings_per_token = json_value(data, "timings_per_token", false); + + params.stream = json_value(data, "stream", false); + auto stream_opt = json_value(data, "stream_options", json::object()); + params.include_usage = json_value(stream_opt, "include_usage", false); + params.cache_prompt = json_value(data, "cache_prompt", true); + params.return_tokens = json_value(data, "return_tokens", false); + params.return_progress = json_value(data, "return_progress", false); + params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); + params.n_indent = json_value(data, "n_indent", defaults.n_indent); + params.n_keep = json_value(data, "n_keep", defaults.n_keep); + params.n_discard = json_value(data, "n_discard", defaults.n_discard); + //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement + params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); + params.response_fields = json_value(data, "response_fields", std::vector()); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); + params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); + params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); + + params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); + params.speculative.n_min = std::max(params.speculative.n_min, 0); + params.speculative.n_max = std::max(params.speculative.n_max, 0); + + // Use OpenAI API logprobs only if n_probs wasn't provided + if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ + params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); + } + + if (data.contains("lora")) { + if (data.at("lora").is_array()) { + params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); + } else { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + } else { + params.lora = params_base.lora_adapters; + } + + // TODO: add more sanity checks for the input parameters + + if (params.sampling.penalty_last_n < -1) { + throw std::runtime_error("Error: repeat_last_n must be >= -1"); + } + + if (params.sampling.dry_penalty_last_n < -1) { + throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); + } + + if (params.sampling.penalty_last_n == -1) { + // note: should be the slot's context and not the full context, but it's ok + params.sampling.penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_base < 1.0f) { + params.sampling.dry_base = defaults.sampling.dry_base; + } + + // sequence breakers for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (params.sampling.dry_sequence_breakers.empty()) { + throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); + } + } + } + + // process "json_schema" and "grammar" + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); + params.sampling.grammar = json_schema_to_grammar(schema); + SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); + } catch (const std::exception & e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } + } else { + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); + params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); + SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); + } + + { + auto it = data.find("chat_format"); + if (it != data.end()) { + params.oaicompat_chat_syntax.format = static_cast(it->get()); + SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format)); + } else { + params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; + } + common_reasoning_format reasoning_format = params_base.reasoning_format; + if (data.contains("reasoning_format")) { + reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get()); + } + params.oaicompat_chat_syntax.reasoning_format = reasoning_format; + params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); + params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); + params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); + } + + { + const auto preserved_tokens = data.find("preserved_tokens"); + if (preserved_tokens != data.end()) { + for (const auto & t : *preserved_tokens) { + auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + SRV_DBG("Preserved token: %d\n", ids[0]); + params.sampling.preserved_tokens.insert(ids[0]); + } else { + // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. + SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); + } + } + } + const auto grammar_triggers = data.find("grammar_triggers"); + if (grammar_triggers != data.end()) { + for (const auto & t : *grammar_triggers) { + server_grammar_trigger ct(t); + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto & word = ct.value.value; + auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + auto token = ids[0]; + if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) { + throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); + } + SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); + common_grammar_trigger trigger; + trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; + trigger.value = word; + trigger.token = token; + params.sampling.grammar_triggers.push_back(std::move(trigger)); + } else { + SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); + params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); + } + } else { + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { + SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); + } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { + SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); + } else { + throw std::runtime_error("Unknown grammar trigger type"); + } + params.sampling.grammar_triggers.emplace_back(std::move(ct.value)); + } + } + } + if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { + throw std::runtime_error("Error: no triggers set for lazy grammar!"); + } + } + + { + params.sampling.logit_bias.clear(); + + const auto & logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto & el : *logit_bias) { + // TODO: we may want to throw errors here, in case "el" is incorrect + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else if (el[0].is_string()) { + auto toks = common_tokenize(vocab, el[0].get(), false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + } else if (logit_bias != data.end() && logit_bias->is_object()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto & el : logit_bias->items()) { + float bias; + const auto & key = el.key(); + const auto & value = el.value(); + if (value.is_number()) { + bias = value.get(); + } else if (value.is_boolean() && !value.get()) { + bias = -INFINITY; + } else { + continue; + } + + char *end; + llama_token tok = strtol(key.c_str(), &end, 10); + if (*end == 0) { + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else { + auto toks = common_tokenize(vocab, key, false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + + params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos); + if (params.sampling.ignore_eos) { + params.sampling.logit_bias.insert( + params.sampling.logit_bias.end(), + defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end()); + } + } + + { + params.antiprompt.clear(); + + const auto & stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto & word : *stop) { + if (!word.empty()) { + params.antiprompt.push_back(word); + } + } + } + // set reverse prompt from cli args if not set in the request + if (params.antiprompt.empty()) { + params.antiprompt = defaults.antiprompt; + } + } + + { + const auto samplers = data.find("samplers"); + if (samplers != data.end()) { + if (samplers->is_array()) { + params.sampling.samplers = common_sampler_types_from_names(*samplers, false); + } else if (samplers->is_string()){ + params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); + } + } else { + params.sampling.samplers = defaults.sampling.samplers; + } + } + + std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; + params.oaicompat_model = json_value(data, "model", model_name); + + return params; +} + +// +// result_timings +// + +json result_timings::to_json() const { + json base = { + {"cache_n", cache_n}, + + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, + {"prompt_per_token_ms", prompt_per_token_ms}, + {"prompt_per_second", prompt_per_second}, + + {"predicted_n", predicted_n}, + {"predicted_ms", predicted_ms}, + {"predicted_per_token_ms", predicted_per_token_ms}, + {"predicted_per_second", predicted_per_second}, + }; + + if (draft_n > 0) { + base["draft_n"] = draft_n; + base["draft_n_accepted"] = draft_n_accepted; + } + + return base; +} + +// +// result_prompt_progress +// +json result_prompt_progress::to_json() const { + return json { + {"total", total}, + {"cache", cache}, + {"processed", processed}, + {"time_ms", time_ms}, + }; +} + +static inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: return "eos"; + case STOP_TYPE_WORD: return "word"; + case STOP_TYPE_LIMIT: return "limit"; + default: return "none"; + } +} + +// +// completion_token_output +// + +json completion_token_output::to_json(bool post_sampling_probs) const { + json probs_for_token = json::array(); + for (const auto & p : probs) { + std::string txt(p.txt); + txt.resize(validate_utf8(txt)); + probs_for_token.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.txt)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + }); + } + return probs_for_token; +} + +json completion_token_output::probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { + json out = json::array(); + for (const auto & p : probs) { + std::string txt(p.text_to_send); + txt.resize(validate_utf8(txt)); + out.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.text_to_send)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + { + post_sampling_probs ? "top_probs" : "top_logprobs", + p.to_json(post_sampling_probs) + }, + }); + } + return out; +} + +float completion_token_output::logarithm(float x) { + // nlohmann::json converts -inf to null, so we need to prevent that + return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); +} + +std::vector completion_token_output::str_to_bytes(const std::string & str) { + std::vector bytes; + for (unsigned char c : str) { + bytes.push_back(c); + } + return bytes; +} + +// +// server_task_result_cmpl_final +// +json server_task_result_cmpl_final::to_json() { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + case OAICOMPAT_TYPE_ANTHROPIC: + return stream ? to_json_anthropic_stream() : to_json_anthropic(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } +} + +json server_task_result_cmpl_final::to_json_non_oaicompat() { + json res = json { + {"index", index}, + {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"tokens", stream ? llama_tokens {} : tokens}, + {"id_slot", id_slot}, + {"stop", true}, + {"model", oaicompat_model}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + {"generation_settings", generation_params.to_json()}, + {"prompt", prompt}, + {"has_new_line", has_new_line}, + {"truncated", truncated}, + {"stop_type", stop_type_to_str(stop)}, + {"stopping_word", stopping_word}, + {"tokens_cached", n_tokens_cached}, + {"timings", timings.to_json()}, + }; + if (!stream && !probs_output.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); + } + return response_fields.empty() ? res : json_get_nested_values(response_fields, res); +} + +json server_task_result_cmpl_final::to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (!stream && probs_output.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + json finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + json res = json { + {"choices", json::array({ + json{ + {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", finish_reason}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; +} + +json server_task_result_cmpl_final::to_json_oaicompat_chat() { + std::string finish_reason = "length"; + common_chat_msg msg; + if (!oaicompat_msg.empty()) { + msg = oaicompat_msg; + } else { + msg.role = "assistant"; + msg.content = content; + } + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; + } + + json choice { + {"finish_reason", finish_reason}, + {"index", 0}, + {"message", msg.to_json_oaicompat()}, + }; + + if (!stream && probs_output.size() > 0) { + choice["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + + std::time_t t = std::time(0); + + json res = json { + {"choices", json::array({choice})}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; +} + +json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { + std::time_t t = std::time(0); + std::string finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls"; + } + + json deltas = json::array(); + for (const auto & diff : oaicompat_msg_diffs) { + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + } + + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + + if (include_usage) { + // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage + // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices + deltas.push_back({ + {"choices", json::array()}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }); + } + + if (timings.prompt_n >= 0) { + deltas.back().push_back({"timings", timings.to_json()}); + } + + // extra fields for debugging purposes + if (verbose && !deltas.empty()) { + deltas.front()["__verbose"] = to_json_non_oaicompat(); + } + + return deltas; +} + +json server_task_result_cmpl_final::to_json_anthropic() { + std::string stop_reason = "max_tokens"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; + } + + json content_blocks = json::array(); + + common_chat_msg msg; + if (!oaicompat_msg.empty()) { + msg = oaicompat_msg; + } else { + msg.role = "assistant"; + msg.content = content; + } + + if (!msg.content.empty()) { + content_blocks.push_back({ + {"type", "text"}, + {"text", msg.content} + }); + } + + for (const auto & tool_call : msg.tool_calls) { + json tool_use_block = { + {"type", "tool_use"}, + {"id", tool_call.id}, + {"name", tool_call.name} + }; + + try { + tool_use_block["input"] = json::parse(tool_call.arguments); + } catch (const std::exception &) { + tool_use_block["input"] = json::object(); + } + + content_blocks.push_back(tool_use_block); + } + + json res = { + {"id", oaicompat_cmpl_id}, + {"type", "message"}, + {"role", "assistant"}, + {"content", content_blocks}, + {"model", oaicompat_model}, + {"stop_reason", stop_reason}, + {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}, + {"usage", { + {"input_tokens", n_prompt_tokens}, + {"output_tokens", n_decoded} + }} + }; + + return res; +} + +json server_task_result_cmpl_final::to_json_anthropic_stream() { + json events = json::array(); + + std::string stop_reason = "max_tokens"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; + } + + bool has_text = !oaicompat_msg.content.empty(); + size_t num_tool_calls = oaicompat_msg.tool_calls.size(); + + bool text_block_started = false; + std::unordered_set tool_calls_started; + + for (const auto & diff : oaicompat_msg_diffs) { + if (!diff.content_delta.empty()) { + if (!text_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", 0}, + {"content_block", { + {"type", "text"}, + {"text", ""} + }} + }} + }); + text_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", 0}, + {"delta", { + {"type", "text_delta"}, + {"text", diff.content_delta} + }} + }} + }); + } + + if (diff.tool_call_index != std::string::npos) { + size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index; + + if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) { + const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index]; + + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", content_block_index}, + {"content_block", { + {"type", "tool_use"}, + {"id", full_tool_call.id}, + {"name", full_tool_call.name} + }} + }} + }); + tool_calls_started.insert(diff.tool_call_index); + } + + if (!diff.tool_call_delta.arguments.empty()) { + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", content_block_index}, + {"delta", { + {"type", "input_json_delta"}, + {"partial_json", diff.tool_call_delta.arguments} + }} + }} + }); + } + } + } + + if (has_text) { + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", 0} + }} + }); + } + + for (size_t i = 0; i < num_tool_calls; i++) { + size_t content_block_index = (has_text ? 1 : 0) + i; + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", content_block_index} + }} + }); + } + + events.push_back({ + {"event", "message_delta"}, + {"data", { + {"type", "message_delta"}, + {"delta", { + {"stop_reason", stop_reason}, + {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)} + }}, + {"usage", { + {"output_tokens", n_decoded} + }} + }} + }); + + events.push_back({ + {"event", "message_stop"}, + {"data", { + {"type", "message_stop"} + }} + }); + + return events; +} + +// +// server_task_result_cmpl_partial +// +json server_task_result_cmpl_partial::to_json() { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return to_json_oaicompat_chat(); + case OAICOMPAT_TYPE_ANTHROPIC: + return to_json_anthropic(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } +} + +json server_task_result_cmpl_partial::to_json_non_oaicompat() { + // non-OAI-compat JSON + json res = json { + {"index", index}, + {"content", content}, + {"tokens", tokens}, + {"stop", false}, + {"id_slot", id_slot}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + }; + // populate the timings object when needed (usually for the last response or with timings_per_token enabled) + if (timings.prompt_n > 0) { + res.push_back({"timings", timings.to_json()}); + } + if (is_progress) { + res.push_back({"prompt_progress", progress.to_json()}); + } + if (!prob_output.probs.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); + } + return res; +} + +json server_task_result_cmpl_partial::to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (prob_output.probs.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + json res = json { + {"choices", json::array({ + json{ + {"text", content}, + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", nullptr}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + if (is_progress) { + res.push_back({"prompt_progress", progress.to_json()}); + } + + return res; +} + +json server_task_result_cmpl_partial::to_json_oaicompat_chat() { + bool first = n_decoded == 1; + std::time_t t = std::time(0); + json choices; + + std::vector deltas; + auto add_delta = [&](const json & delta) { + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", delta}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + }; + // We have to send an initial update to conform to openai behavior + if (first || is_progress) { + add_delta({ + {"role", "assistant"}, + {"content", nullptr}, + }); + } + + for (const auto & diff : oaicompat_msg_diffs) { + add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); + } + + if (!deltas.empty()) { + auto & last_json = deltas[deltas.size() - 1]; + GGML_ASSERT(last_json.at("choices").size() >= 1); + + if (prob_output.probs.size() > 0) { + last_json.at("choices").at(0)["logprobs"] = json { + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + + if (timings.prompt_n >= 0) { + last_json.push_back({"timings", timings.to_json()}); + } + if (is_progress) { + last_json.push_back({"prompt_progress", progress.to_json()}); + } + } + + return deltas; +} + +// +// server_task_result_embd +// +json server_task_result_embd::to_json() { + return oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? to_json_oaicompat() + : to_json_non_oaicompat(); +} + +json server_task_result_embd::to_json_non_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding}, + }; +} + +json server_task_result_embd::to_json_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding[0]}, + {"tokens_evaluated", n_tokens}, + }; +} + +// +// server_task_result_rerank +// +json server_task_result_rerank::to_json() { + return json { + {"index", index}, + {"score", score}, + {"tokens_evaluated", n_tokens}, + }; +} + +json server_task_result_cmpl_partial::to_json_anthropic() { + json events = json::array(); + bool first = (n_decoded == 1); + static bool text_block_started = false; + + if (first) { + text_block_started = false; + + events.push_back({ + {"event", "message_start"}, + {"data", { + {"type", "message_start"}, + {"message", { + {"id", oaicompat_cmpl_id}, + {"type", "message"}, + {"role", "assistant"}, + {"content", json::array()}, + {"model", oaicompat_model}, + {"stop_reason", nullptr}, + {"stop_sequence", nullptr}, + {"usage", { + {"input_tokens", n_prompt_tokens}, + {"output_tokens", 0} + }} + }} + }} + }); + } + + for (const auto & diff : oaicompat_msg_diffs) { + if (!diff.content_delta.empty()) { + if (!text_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", 0}, + {"content_block", { + {"type", "text"}, + {"text", ""} + }} + }} + }); + text_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", 0}, + {"delta", { + {"type", "text_delta"}, + {"text", diff.content_delta} + }} + }} + }); + } + + if (diff.tool_call_index != std::string::npos) { + size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index; + + if (!diff.tool_call_delta.name.empty()) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", content_block_index}, + {"content_block", { + {"type", "tool_use"}, + {"id", diff.tool_call_delta.id}, + {"name", diff.tool_call_delta.name} + }} + }} + }); + } + + if (!diff.tool_call_delta.arguments.empty()) { + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", content_block_index}, + {"delta", { + {"type", "input_json_delta"}, + {"partial_json", diff.tool_call_delta.arguments} + }} + }} + }); + } + } + } + + return events; +} + +// +// server_task_result_error +// +json server_task_result_error::to_json() { + json res = format_error_response(err_msg, err_type); + if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { + res["n_prompt_tokens"] = n_prompt_tokens; + res["n_ctx"] = n_ctx; + } + return res; +} + +// +// server_task_result_metrics +// +json server_task_result_metrics::to_json() { + return json { + { "idle", n_idle_slots }, + { "processing", n_processing_slots }, + { "deferred", n_tasks_deferred }, + { "t_start", t_start }, + + { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, + { "t_tokens_generation_total", t_tokens_generation_total }, + { "n_tokens_predicted_total", n_tokens_predicted_total }, + { "t_prompt_processing_total", t_prompt_processing_total }, + + { "n_tokens_max", n_tokens_max }, + + { "n_prompt_tokens_processed", n_prompt_tokens_processed }, + { "t_prompt_processing", t_prompt_processing }, + { "n_tokens_predicted", n_tokens_predicted }, + { "t_tokens_generation", t_tokens_generation }, + + { "n_decode_total", n_decode_total }, + { "n_busy_slots_total", n_busy_slots_total }, + + { "slots", slots_data }, + }; +} + +// +// server_task_result_slot_save_load +// +json server_task_result_slot_save_load::to_json() { + if (is_save) { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", n_tokens }, + { "n_written", n_bytes }, + { "timings", { + { "save_ms", t_ms } + }}, + }; + } + + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; +} + +// +// server_task_result_slot_erase +// +json server_task_result_slot_erase::to_json() { + return json { + { "id_slot", id_slot }, + { "n_erased", n_erased }, + }; +} + +// +// server_task_result_apply_lora +// + +json server_task_result_apply_lora::to_json() { + return json {{ "success", true }}; +} + +// +// server_prompt_cache +// +size_t server_prompt_cache::size() const { + size_t res = 0; + + for (const auto & state : states) { + res += state.size(); + } + + return res; +} + +size_t server_prompt_cache::n_tokens() const { + size_t res = 0; + + for (const auto & state : states) { + res += state.n_tokens(); + } + + return res; +} + +server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) { + // first check if the current state is contained fully in the cache + for (auto it = states.begin(); it != states.end(); ++it) { + const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); + + if (cur_lcp_len == (int) prompt.tokens.size()) { + SRV_WRN("%s", " - prompt is already in the cache, skipping\n"); + return nullptr; + } + } + + // next, remove any cached prompts that are fully contained in the current prompt + for (auto it = states.begin(); it != states.end();) { + const int len = it->tokens.get_common_prefix(prompt.tokens); + + if (len == (int) it->tokens.size()) { + SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); + + it = states.erase(it); + } else { + ++it; + } + } + + std::vector state_data; + + // check if we can allocate enough memory for the new state + try { + state_data.resize(state_size); + } catch (const std::bad_alloc & e) { + SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); + + limit_size = std::max(1, 0.4*size()); + + SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); + + update(); + + return nullptr; + } + + // TODO: for some reason we can't copy server_tokens, so we have to do this workaround + auto & cur = states.emplace_back(); + cur = { + /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), + /*.data =*/ std::move(state_data), + /*.checkpoints =*/ prompt.checkpoints, + }; + + return &cur; +} + +bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { + const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); + + float f_keep_best = float(lcp_best) / prompt.tokens.size(); + float sim_best = float(lcp_best) / tokens_new.size(); + + SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + auto it_best = states.end(); + + // find the most similar cached prompt, that would also preserve the most context + for (auto it = states.begin(); it != states.end(); ++it) { + const int lcp_cur = it->tokens.get_common_prefix(tokens_new); + + const float f_keep_cur = float(lcp_cur) / it->tokens.size(); + const float sim_cur = float(lcp_cur) / tokens_new.size(); + + // don't trash large prompts + if (f_keep_cur < 0.25f) { + continue; + } + + if (f_keep_best < f_keep_cur && sim_best < sim_cur) { + f_keep_best = f_keep_cur; + sim_best = sim_cur; + + it_best = it; + } + } + + if (it_best != states.end()) { + SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + const size_t size = it_best->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); + if (n != size) { + SRV_WRN("failed to restore state with size %zu\n", size); + + return false; + } + + it_best->data.clear(); + it_best->data.shrink_to_fit(); + + prompt = std::move(*it_best); + + states.erase(it_best); + } + + return true; +} + +void server_prompt_cache::update() { + if (limit_size > 0) { + // always keep at least one state, regardless of the limits + while (states.size() > 1 && size() > limit_size) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + // average size per token + const float size_per_token = std::max(1.0f, float(size()) / (std::max(1, n_tokens()))); + + // dynamically increase the token limit if it can fit in the memory limit + const size_t limit_tokens_cur = limit_size > 0 ? std::max(limit_tokens, limit_size/size_per_token) : limit_tokens; + + if (limit_tokens > 0) { + while (states.size() > 1 && n_tokens() > limit_tokens_cur) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n", + limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n", + states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur); + + for (const auto & state : states) { + SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", + (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + } +} diff --git a/tools/server/server-task.h b/tools/server/server-task.h new file mode 100644 index 00000000000..b96c00a96a5 --- /dev/null +++ b/tools/server/server-task.h @@ -0,0 +1,460 @@ +#pragma once + +#include "common.h" +#include "llama.h" + +#include +#include +#include + +// TODO: prevent including the whole server-common.h as we only use server_tokens +#include "server-common.h" + +using json = nlohmann::ordered_json; + +enum server_task_type { + SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_EMBEDDING, + SERVER_TASK_TYPE_RERANK, + SERVER_TASK_TYPE_INFILL, + SERVER_TASK_TYPE_CANCEL, + SERVER_TASK_TYPE_NEXT_RESPONSE, + SERVER_TASK_TYPE_METRICS, + SERVER_TASK_TYPE_SLOT_SAVE, + SERVER_TASK_TYPE_SLOT_RESTORE, + SERVER_TASK_TYPE_SLOT_ERASE, + SERVER_TASK_TYPE_SET_LORA, +}; + +// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common +enum oaicompat_type { + OAICOMPAT_TYPE_NONE, + OAICOMPAT_TYPE_CHAT, + OAICOMPAT_TYPE_COMPLETION, + OAICOMPAT_TYPE_EMBEDDING, + OAICOMPAT_TYPE_ANTHROPIC, +}; + +enum stop_type { + STOP_TYPE_NONE, + STOP_TYPE_EOS, + STOP_TYPE_WORD, + STOP_TYPE_LIMIT, +}; + +struct task_params { + bool stream = true; + bool include_usage = false; + bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt + bool return_tokens = false; + bool return_progress = false; + + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half + int32_t n_predict = -1; // new tokens to predict + int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters + + int64_t t_max_prompt_ms = -1; // TODO: implement + int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit + + std::vector lora; + + std::vector antiprompt; + std::vector response_fields; + bool timings_per_token = false; + bool post_sampling_probs = false; + + struct common_params_sampling sampling; + struct common_params_speculative speculative; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_syntax oaicompat_chat_syntax; + + // Embeddings + int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) + + json format_logit_bias(const std::vector & logit_bias) const; + json to_json(bool only_metrics = false) const; +}; + +struct server_task { + int id = -1; // to be filled by server_queue + int index = -1; // used when there are multiple prompts (batch request) + + // used by SERVER_TASK_TYPE_CANCEL + int id_target = -1; + int id_slot = -1; + + // used by SERVER_TASK_TYPE_INFERENCE + task_params params; + server_tokens tokens; + + server_task_type type; + + // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE + struct slot_action { + int slot_id; + std::string filename; + std::string filepath; + }; + slot_action slot_action; + + // used by SERVER_TASK_TYPE_METRICS + bool metrics_reset_bucket = false; + + // used by SERVER_TASK_TYPE_SET_LORA + std::vector set_lora; + + server_task() = default; + + server_task(server_task_type type) : type(type) {} + + int32_t n_tokens() const { + return tokens.size(); + } + + static task_params params_from_json_cmpl( + const llama_context * ctx, + const common_params & params_base, + const json & data); + + // utility function + static std::unordered_set get_list_id(const std::vector & tasks) { + std::unordered_set ids(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + ids.insert(tasks[i].id); + } + return ids; + } +}; + +struct result_timings { + int32_t cache_n = -1; + + int32_t prompt_n = -1; + double prompt_ms; + double prompt_per_token_ms; + double prompt_per_second; + + int32_t predicted_n = -1; + double predicted_ms; + double predicted_per_token_ms; + double predicted_per_second; + + // Optional speculative metrics - only included when > 0 + int32_t draft_n = 0; + int32_t draft_n_accepted = 0; + + json to_json() const; +}; + +struct result_prompt_progress { + int32_t total = 0; + int32_t cache = 0; + int32_t processed = 0; + int64_t time_ms = 0; + + json to_json() const; +}; + +struct server_task_result { + int id = -1; + int id_slot = -1; + virtual bool is_error() { + // only used by server_task_result_error + return false; + } + virtual bool is_stop() { + // only used by server_task_result_cmpl_* + return true; + } + virtual int get_index() { + return -1; + } + virtual json to_json() = 0; + virtual ~server_task_result() = default; +}; + +// using shared_ptr for polymorphism of server_task_result +using server_task_result_ptr = std::unique_ptr; + +struct completion_token_output { + llama_token tok; + float prob; + std::string text_to_send; + struct prob_info { + llama_token tok; + std::string txt; + float prob; + }; + std::vector probs; + + json to_json(bool post_sampling_probs) const; + + static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs); + + static float logarithm(float x); + + static std::vector str_to_bytes(const std::string & str); + +}; + +struct server_task_result_cmpl_final : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + bool stream; + bool include_usage; + result_timings timings; + std::string prompt; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + int32_t n_tokens_cached; + bool has_new_line; + std::string stopping_word; + stop_type stop = STOP_TYPE_NONE; + + bool post_sampling_probs; + std::vector probs_output; + std::vector response_fields; + + task_params generation_params; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_msg oaicompat_msg; + + std::vector oaicompat_msg_diffs; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return true; // in stream mode, final responses are considered stop + } + + virtual json to_json() override; + + json to_json_non_oaicompat(); + + json to_json_oaicompat(); + + json to_json_oaicompat_chat(); + + json to_json_oaicompat_chat_stream(); + + json to_json_anthropic(); + + json to_json_anthropic_stream(); +}; + +struct server_task_result_cmpl_partial : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + int32_t n_decoded; + int32_t n_prompt_tokens; + + bool post_sampling_probs; + bool is_progress = false; + completion_token_output prob_output; + result_timings timings; + result_prompt_progress progress; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + std::vector oaicompat_msg_diffs; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return false; // in stream mode, partial responses are not considered stop + } + + virtual json to_json() override; + + json to_json_non_oaicompat(); + + json to_json_oaicompat(); + + json to_json_oaicompat_chat(); + + json to_json_anthropic(); +}; + +struct server_task_result_embd : server_task_result { + int index = 0; + std::vector> embedding; + + int32_t n_tokens; + + // OAI-compat fields + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override; + + json to_json_non_oaicompat(); + + json to_json_oaicompat(); +}; + +struct server_task_result_rerank : server_task_result { + int index = 0; + float score = -1e6; + + int32_t n_tokens; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override; +}; + +struct server_task_result_error : server_task_result { + int index = 0; + error_type err_type = ERROR_TYPE_SERVER; + std::string err_msg; + + // for ERROR_TYPE_EXCEED_CONTEXT_SIZE + int32_t n_prompt_tokens = 0; + int32_t n_ctx = 0; + + virtual bool is_error() override { + return true; + } + + virtual json to_json() override; +}; + +struct server_task_result_metrics : server_task_result { + int n_idle_slots; + int n_processing_slots; + int n_tasks_deferred; + int64_t t_start; + + // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_tokens_max = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + // while we can also use std::vector this requires copying the slot object which can be quite messy + // therefore, we use json to temporarily store the slot.to_json() result + json slots_data = json::array(); + + virtual json to_json() override; +}; + +struct server_task_result_slot_save_load : server_task_result { + std::string filename; + bool is_save; // true = save, false = load + + size_t n_tokens; + size_t n_bytes; + double t_ms; + + virtual json to_json() override; +}; + +struct server_task_result_slot_erase : server_task_result { + size_t n_erased; + + virtual json to_json() override; +}; + +struct server_task_result_apply_lora : server_task_result { + virtual json to_json() override; +}; + +struct server_prompt_checkpoint { + llama_pos pos_min; + llama_pos pos_max; + + std::vector data; + + size_t size() const { + return data.size(); + } +}; + +struct server_prompt { + server_tokens tokens; + + std::vector data; + + std::list checkpoints; + + size_t size() const { + size_t res = data.size(); + + for (const auto & checkpoint : checkpoints) { + res += checkpoint.size(); + } + + return res; + } + + int n_tokens() const { + return tokens.size(); + } +}; + +struct server_prompt_cache { + server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { + this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib); + this->limit_tokens = limit_tokens; + } + + std::list states; + + // in bytes, 0 = no limit + size_t limit_size = 0; + + // in tokens, 0 = no limit + size_t limit_tokens = 0; + + size_t size() const; + + size_t n_tokens() const; + + server_prompt * alloc(const server_prompt & prompt, size_t state_size); + + bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot); + + void update(); +}; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 3750c8fdb60..c9027bacef2 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1,25 +1,21 @@ -#include "chat.h" -#include "utils.hpp" +#include "server-common.h" #include "server-http.h" +#include "server-task.h" +#include "server-queue.h" #include "arg.h" #include "common.h" -#include "json-schema-to-grammar.h" #include "llama.h" #include "log.h" #include "sampling.h" #include "speculative.h" #include "mtmd.h" +#include "mtmd-helper.h" #include -#include -#include #include #include -#include #include -#include -#include #include #include #include @@ -37,1589 +33,39 @@ using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; -enum stop_type { - STOP_TYPE_NONE, - STOP_TYPE_EOS, - STOP_TYPE_WORD, - STOP_TYPE_LIMIT, -}; - -// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 -enum slot_state { - SLOT_STATE_IDLE, - SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future - SLOT_STATE_PROCESSING_PROMPT, - SLOT_STATE_DONE_PROMPT, - SLOT_STATE_GENERATING, -}; - -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded -}; - -enum server_task_type { - SERVER_TASK_TYPE_COMPLETION, - SERVER_TASK_TYPE_EMBEDDING, - SERVER_TASK_TYPE_RERANK, - SERVER_TASK_TYPE_INFILL, - SERVER_TASK_TYPE_CANCEL, - SERVER_TASK_TYPE_NEXT_RESPONSE, - SERVER_TASK_TYPE_METRICS, - SERVER_TASK_TYPE_SLOT_SAVE, - SERVER_TASK_TYPE_SLOT_RESTORE, - SERVER_TASK_TYPE_SLOT_ERASE, - SERVER_TASK_TYPE_SET_LORA, -}; - -enum oaicompat_type { - OAICOMPAT_TYPE_NONE, - OAICOMPAT_TYPE_CHAT, - OAICOMPAT_TYPE_COMPLETION, - OAICOMPAT_TYPE_EMBEDDING, -}; - -// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 -enum error_type { - ERROR_TYPE_INVALID_REQUEST, - ERROR_TYPE_AUTHENTICATION, - ERROR_TYPE_SERVER, - ERROR_TYPE_NOT_FOUND, - ERROR_TYPE_PERMISSION, - ERROR_TYPE_UNAVAILABLE, // custom error - ERROR_TYPE_NOT_SUPPORTED, // custom error - ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error -}; - -static bool server_task_type_need_embd(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_EMBEDDING: - case SERVER_TASK_TYPE_RERANK: - return true; - default: - return false; - } -} - -static bool server_task_type_need_logits(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_COMPLETION: - case SERVER_TASK_TYPE_INFILL: - return true; - default: - return false; - } -} - -struct slot_params { - bool stream = true; - bool include_usage = false; - bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt - bool return_tokens = false; - bool return_progress = false; - - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half - int32_t n_predict = -1; // new tokens to predict - int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters - - int64_t t_max_prompt_ms = -1; // TODO: implement - int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit - - std::vector lora; - - std::vector antiprompt; - std::vector response_fields; - bool timings_per_token = false; - bool post_sampling_probs = false; - - struct common_params_sampling sampling; - struct common_params_speculative speculative; - - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_syntax oaicompat_chat_syntax; - - // Embeddings - int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) - - json to_json(bool only_metrics = false) const { - std::vector samplers; - samplers.reserve(sampling.samplers.size()); - for (const auto & sampler : sampling.samplers) { - samplers.emplace_back(common_sampler_type_to_str(sampler)); - } - - json lora = json::array(); - for (size_t i = 0; i < this->lora.size(); ++i) { - lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); - } - - if (only_metrics) { - return json { - {"seed", sampling.seed}, - {"temperature", sampling.temp}, - {"dynatemp_range", sampling.dynatemp_range}, - {"dynatemp_exponent", sampling.dynatemp_exponent}, - {"top_k", sampling.top_k}, - {"top_p", sampling.top_p}, - {"min_p", sampling.min_p}, - {"top_n_sigma", sampling.top_n_sigma}, - {"xtc_probability", sampling.xtc_probability}, - {"xtc_threshold", sampling.xtc_threshold}, - {"typical_p", sampling.typ_p}, - {"repeat_last_n", sampling.penalty_last_n}, - {"repeat_penalty", sampling.penalty_repeat}, - {"presence_penalty", sampling.penalty_present}, - {"frequency_penalty", sampling.penalty_freq}, - {"dry_multiplier", sampling.dry_multiplier}, - {"dry_base", sampling.dry_base}, - {"dry_allowed_length", sampling.dry_allowed_length}, - {"dry_penalty_last_n", sampling.dry_penalty_last_n}, - {"mirostat", sampling.mirostat}, - {"mirostat_tau", sampling.mirostat_tau}, - {"mirostat_eta", sampling.mirostat_eta}, - {"max_tokens", n_predict}, - {"n_predict", n_predict}, // TODO: deduplicate? - {"n_keep", n_keep}, - {"n_discard", n_discard}, - {"ignore_eos", sampling.ignore_eos}, - {"stream", stream}, - {"n_probs", sampling.n_probs}, - {"min_keep", sampling.min_keep}, - {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, - {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, - {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, - {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, - {"samplers", samplers}, - {"speculative.n_max", speculative.n_max}, - {"speculative.n_min", speculative.n_min}, - {"speculative.p_min", speculative.p_min}, - {"timings_per_token", timings_per_token}, - {"post_sampling_probs", post_sampling_probs}, - {"lora", lora}, - }; - } - - auto grammar_triggers = json::array(); - for (const auto & trigger : sampling.grammar_triggers) { - server_grammar_trigger ct(trigger); - grammar_triggers.push_back(ct.to_json()); - } - - return json { - {"seed", sampling.seed}, - {"temperature", sampling.temp}, - {"dynatemp_range", sampling.dynatemp_range}, - {"dynatemp_exponent", sampling.dynatemp_exponent}, - {"top_k", sampling.top_k}, - {"top_p", sampling.top_p}, - {"min_p", sampling.min_p}, - {"top_n_sigma", sampling.top_n_sigma}, - {"xtc_probability", sampling.xtc_probability}, - {"xtc_threshold", sampling.xtc_threshold}, - {"typical_p", sampling.typ_p}, - {"repeat_last_n", sampling.penalty_last_n}, - {"repeat_penalty", sampling.penalty_repeat}, - {"presence_penalty", sampling.penalty_present}, - {"frequency_penalty", sampling.penalty_freq}, - {"dry_multiplier", sampling.dry_multiplier}, - {"dry_base", sampling.dry_base}, - {"dry_allowed_length", sampling.dry_allowed_length}, - {"dry_penalty_last_n", sampling.dry_penalty_last_n}, - {"dry_sequence_breakers", sampling.dry_sequence_breakers}, - {"mirostat", sampling.mirostat}, - {"mirostat_tau", sampling.mirostat_tau}, - {"mirostat_eta", sampling.mirostat_eta}, - {"stop", antiprompt}, - {"max_tokens", n_predict}, - {"n_predict", n_predict}, // TODO: deduplicate? - {"n_keep", n_keep}, - {"n_discard", n_discard}, - {"ignore_eos", sampling.ignore_eos}, - {"stream", stream}, - {"logit_bias", format_logit_bias(sampling.logit_bias)}, - {"n_probs", sampling.n_probs}, - {"min_keep", sampling.min_keep}, - {"grammar", sampling.grammar}, - {"grammar_lazy", sampling.grammar_lazy}, - {"grammar_triggers", grammar_triggers}, - {"preserved_tokens", sampling.preserved_tokens}, - {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, - {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, - {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, - {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, - {"samplers", samplers}, - {"speculative.n_max", speculative.n_max}, - {"speculative.n_min", speculative.n_min}, - {"speculative.p_min", speculative.p_min}, - {"timings_per_token", timings_per_token}, - {"post_sampling_probs", post_sampling_probs}, - {"lora", lora}, - }; - } -}; - -struct server_task { - int id = -1; // to be filled by server_queue - int index = -1; // used when there are multiple prompts (batch request) - - // used by SERVER_TASK_TYPE_CANCEL - int id_target = -1; - int id_slot = -1; - - // used by SERVER_TASK_TYPE_INFERENCE - slot_params params; - server_tokens tokens; - - server_task_type type; - - // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE - struct slot_action { - int slot_id; - std::string filename; - std::string filepath; - }; - slot_action slot_action; - - // used by SERVER_TASK_TYPE_METRICS - bool metrics_reset_bucket = false; - - // used by SERVER_TASK_TYPE_SET_LORA - std::vector set_lora; - - server_task() = default; - - server_task(server_task_type type) : type(type) {} - - int32_t n_tokens() const { - return tokens.size(); - } - - static slot_params params_from_json_cmpl( - const llama_context * ctx, - const common_params & params_base, - const json & data) { - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - - slot_params params; - - // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) - slot_params defaults; - defaults.sampling = params_base.sampling; - defaults.speculative = params_base.speculative; - defaults.n_keep = params_base.n_keep; - defaults.n_predict = params_base.n_predict; - defaults.antiprompt = params_base.antiprompt; - - // enabling this will output extra debug information in the HTTP responses from the server - params.verbose = params_base.verbosity > 9; - params.timings_per_token = json_value(data, "timings_per_token", false); - - params.stream = json_value(data, "stream", false); - auto stream_opt = json_value(data, "stream_options", json::object()); - params.include_usage = json_value(stream_opt, "include_usage", false); - params.cache_prompt = json_value(data, "cache_prompt", true); - params.return_tokens = json_value(data, "return_tokens", false); - params.return_progress = json_value(data, "return_progress", false); - params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); - params.n_indent = json_value(data, "n_indent", defaults.n_indent); - params.n_keep = json_value(data, "n_keep", defaults.n_keep); - params.n_discard = json_value(data, "n_discard", defaults.n_discard); - //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement - params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); - params.response_fields = json_value(data, "response_fields", std::vector()); - - params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); - params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); - params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); - params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); - params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); - params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); - params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); - params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); - params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); - params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); - params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); - params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); - params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); - params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); - params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); - params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); - params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); - params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); - params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); - params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); - params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); - params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); - params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); - params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); - - params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); - params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); - params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); - - params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); - params.speculative.n_min = std::max(params.speculative.n_min, 0); - params.speculative.n_max = std::max(params.speculative.n_max, 0); - - // Use OpenAI API logprobs only if n_probs wasn't provided - if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ - params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); - } - - if (data.contains("lora")) { - if (data.at("lora").is_array()) { - params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); - } else { - throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); - } - } else { - params.lora = params_base.lora_adapters; - } - - // TODO: add more sanity checks for the input parameters - - if (params.sampling.penalty_last_n < -1) { - throw std::runtime_error("Error: repeat_last_n must be >= -1"); - } - - if (params.sampling.dry_penalty_last_n < -1) { - throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); - } - - if (params.sampling.penalty_last_n == -1) { - // note: should be the slot's context and not the full context, but it's ok - params.sampling.penalty_last_n = llama_n_ctx(ctx); - } - - if (params.sampling.dry_penalty_last_n == -1) { - params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); - } - - if (params.sampling.dry_base < 1.0f) { - params.sampling.dry_base = defaults.sampling.dry_base; - } - - // sequence breakers for DRY - { - // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format - // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 - - if (data.contains("dry_sequence_breakers")) { - params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); - if (params.sampling.dry_sequence_breakers.empty()) { - throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); - } - } - } - - // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.contains("grammar")) { - try { - auto schema = json_value(data, "json_schema", json::object()); - SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); - params.sampling.grammar = json_schema_to_grammar(schema); - SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); - } catch (const std::exception & e) { - throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); - } - } else { - params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); - SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); - params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); - SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); - } - - { - auto it = data.find("chat_format"); - if (it != data.end()) { - params.oaicompat_chat_syntax.format = static_cast(it->get()); - SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format)); - } else { - params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; - } - common_reasoning_format reasoning_format = params_base.reasoning_format; - if (data.contains("reasoning_format")) { - reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get()); - } - params.oaicompat_chat_syntax.reasoning_format = reasoning_format; - params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); - params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); - params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); - } - - { - const auto preserved_tokens = data.find("preserved_tokens"); - if (preserved_tokens != data.end()) { - for (const auto & t : *preserved_tokens) { - auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - SRV_DBG("Preserved token: %d\n", ids[0]); - params.sampling.preserved_tokens.insert(ids[0]); - } else { - // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. - SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); - } - } - } - const auto grammar_triggers = data.find("grammar_triggers"); - if (grammar_triggers != data.end()) { - for (const auto & t : *grammar_triggers) { - server_grammar_trigger ct(t); - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { - const auto & word = ct.value.value; - auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - auto token = ids[0]; - if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) { - throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); - } - SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); - common_grammar_trigger trigger; - trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; - trigger.value = word; - trigger.token = token; - params.sampling.grammar_triggers.push_back(std::move(trigger)); - } else { - SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); - params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); - } - } else { - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { - SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); - } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { - SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); - } else { - throw std::runtime_error("Unknown grammar trigger type"); - } - params.sampling.grammar_triggers.emplace_back(std::move(ct.value)); - } - } - } - if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { - throw std::runtime_error("Error: no triggers set for lazy grammar!"); - } - } - - { - params.sampling.logit_bias.clear(); - - const auto & logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & el : *logit_bias) { - // TODO: we may want to throw errors here, in case "el" is incorrect - if (el.is_array() && el.size() == 2) { - float bias; - if (el[1].is_number()) { - bias = el[1].get(); - } else if (el[1].is_boolean() && !el[1].get()) { - bias = -INFINITY; - } else { - continue; - } - - if (el[0].is_number_integer()) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } else if (el[0].is_string()) { - auto toks = common_tokenize(vocab, el[0].get(), false); - for (auto tok : toks) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } - } - } - } else if (logit_bias != data.end() && logit_bias->is_object()) { - const int n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & el : logit_bias->items()) { - float bias; - const auto & key = el.key(); - const auto & value = el.value(); - if (value.is_number()) { - bias = value.get(); - } else if (value.is_boolean() && !value.get()) { - bias = -INFINITY; - } else { - continue; - } - - char *end; - llama_token tok = strtol(key.c_str(), &end, 10); - if (*end == 0) { - if (tok >= 0 && tok < n_vocab) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } else { - auto toks = common_tokenize(vocab, key, false); - for (auto tok : toks) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } - } - } - - params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos); - if (params.sampling.ignore_eos) { - params.sampling.logit_bias.insert( - params.sampling.logit_bias.end(), - defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end()); - } - } - - { - params.antiprompt.clear(); - - const auto & stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto & word : *stop) { - if (!word.empty()) { - params.antiprompt.push_back(word); - } - } - } - // set reverse prompt from cli args if not set in the request - if (params.antiprompt.empty()) { - params.antiprompt = defaults.antiprompt; - } - } - - { - const auto samplers = data.find("samplers"); - if (samplers != data.end()) { - if (samplers->is_array()) { - params.sampling.samplers = common_sampler_types_from_names(*samplers, false); - } else if (samplers->is_string()){ - params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); - } - } else { - params.sampling.samplers = defaults.sampling.samplers; - } - } - - std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; - params.oaicompat_model = json_value(data, "model", model_name); - - return params; - } - - // utility function - static std::unordered_set get_list_id(const std::vector & tasks) { - std::unordered_set ids(tasks.size()); - for (size_t i = 0; i < tasks.size(); i++) { - ids.insert(tasks[i].id); - } - return ids; - } -}; - -struct result_timings { - int32_t cache_n = -1; - - int32_t prompt_n = -1; - double prompt_ms; - double prompt_per_token_ms; - double prompt_per_second; - - int32_t predicted_n = -1; - double predicted_ms; - double predicted_per_token_ms; - double predicted_per_second; - - // Optional speculative metrics - only included when > 0 - int32_t draft_n = 0; - int32_t draft_n_accepted = 0; - - json to_json() const { - json base = { - {"cache_n", cache_n}, - - {"prompt_n", prompt_n}, - {"prompt_ms", prompt_ms}, - {"prompt_per_token_ms", prompt_per_token_ms}, - {"prompt_per_second", prompt_per_second}, - - {"predicted_n", predicted_n}, - {"predicted_ms", predicted_ms}, - {"predicted_per_token_ms", predicted_per_token_ms}, - {"predicted_per_second", predicted_per_second}, - }; - - if (draft_n > 0) { - base["draft_n"] = draft_n; - base["draft_n_accepted"] = draft_n_accepted; - } - - return base; - } -}; - -struct result_prompt_progress { - int32_t total = 0; - int32_t cache = 0; - int32_t processed = 0; - int64_t time_ms = 0; - - json to_json() const { - return json { - {"total", total}, - {"cache", cache}, - {"processed", processed}, - {"time_ms", time_ms}, - }; - } -}; - -struct server_task_result { - int id = -1; - int id_slot = -1; - virtual bool is_error() { - // only used by server_task_result_error - return false; - } - virtual bool is_stop() { - // only used by server_task_result_cmpl_* - return true; - } - virtual int get_index() { - return -1; - } - virtual json to_json() = 0; - virtual ~server_task_result() = default; -}; - -// using shared_ptr for polymorphism of server_task_result -using server_task_result_ptr = std::unique_ptr; - -static inline std::string stop_type_to_str(stop_type type) { - switch (type) { - case STOP_TYPE_EOS: return "eos"; - case STOP_TYPE_WORD: return "word"; - case STOP_TYPE_LIMIT: return "limit"; - default: return "none"; - } -} - -struct completion_token_output { - llama_token tok; - float prob; - std::string text_to_send; - struct prob_info { - llama_token tok; - std::string txt; - float prob; - }; - std::vector probs; - - json to_json(bool post_sampling_probs) const { - json probs_for_token = json::array(); - for (const auto & p : probs) { - std::string txt(p.txt); - txt.resize(validate_utf8(txt)); - probs_for_token.push_back(json { - {"id", p.tok}, - {"token", txt}, - {"bytes", str_to_bytes(p.txt)}, - { - post_sampling_probs ? "prob" : "logprob", - post_sampling_probs ? p.prob : logarithm(p.prob) - }, - }); - } - return probs_for_token; - } - - static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { - json out = json::array(); - for (const auto & p : probs) { - std::string txt(p.text_to_send); - txt.resize(validate_utf8(txt)); - out.push_back(json { - {"id", p.tok}, - {"token", txt}, - {"bytes", str_to_bytes(p.text_to_send)}, - { - post_sampling_probs ? "prob" : "logprob", - post_sampling_probs ? p.prob : logarithm(p.prob) - }, - { - post_sampling_probs ? "top_probs" : "top_logprobs", - p.to_json(post_sampling_probs) - }, - }); - } - return out; - } - - static float logarithm(float x) { - // nlohmann::json converts -inf to null, so we need to prevent that - return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); - } - - static std::vector str_to_bytes(const std::string & str) { - std::vector bytes; - for (unsigned char c : str) { - bytes.push_back(c); - } - return bytes; - } -}; - -struct server_task_result_cmpl_final : server_task_result { - int index = 0; - - std::string content; - llama_tokens tokens; - - bool stream; - bool include_usage; - result_timings timings; - std::string prompt; - - bool truncated; - int32_t n_decoded; - int32_t n_prompt_tokens; - int32_t n_tokens_cached; - bool has_new_line; - std::string stopping_word; - stop_type stop = STOP_TYPE_NONE; - - bool post_sampling_probs; - std::vector probs_output; - std::vector response_fields; - - slot_params generation_params; - - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_msg oaicompat_msg; - - std::vector oaicompat_msg_diffs; - - virtual int get_index() override { - return index; - } - - virtual bool is_stop() override { - return true; // in stream mode, final responses are considered stop - } - - virtual json to_json() override { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: - return to_json_non_oaicompat(); - case OAICOMPAT_TYPE_COMPLETION: - return to_json_oaicompat(); - case OAICOMPAT_TYPE_CHAT: - return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); - default: - GGML_ASSERT(false && "Invalid oaicompat_type"); - } - } - - json to_json_non_oaicompat() { - json res = json { - {"index", index}, - {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk - {"tokens", stream ? llama_tokens {} : tokens}, - {"id_slot", id_slot}, - {"stop", true}, - {"model", oaicompat_model}, - {"tokens_predicted", n_decoded}, - {"tokens_evaluated", n_prompt_tokens}, - {"generation_settings", generation_params.to_json()}, - {"prompt", prompt}, - {"has_new_line", has_new_line}, - {"truncated", truncated}, - {"stop_type", stop_type_to_str(stop)}, - {"stopping_word", stopping_word}, - {"tokens_cached", n_tokens_cached}, - {"timings", timings.to_json()}, - }; - if (!stream && !probs_output.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); - } - return response_fields.empty() ? res : json_get_nested_values(response_fields, res); - } - - json to_json_oaicompat() { - std::time_t t = std::time(0); - json logprobs = json(nullptr); // OAI default to null - if (!stream && probs_output.size() > 0) { - logprobs = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - json finish_reason = "length"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = "stop"; - } - json res = json { - {"choices", json::array({ - json{ - {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk - {"index", index}, - {"logprobs", logprobs}, - {"finish_reason", finish_reason}, - } - })}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "text_completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - - return res; - } - - json to_json_oaicompat_chat() { - std::string finish_reason = "length"; - common_chat_msg msg; - if (!oaicompat_msg.empty()) { - msg = oaicompat_msg; - } else { - msg.role = "assistant"; - msg.content = content; - } - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; - } - - json choice { - {"finish_reason", finish_reason}, - {"index", 0}, - {"message", msg.to_json_oaicompat()}, - }; - - if (!stream && probs_output.size() > 0) { - choice["logprobs"] = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - - std::time_t t = std::time(0); - - json res = json { - {"choices", json::array({choice})}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - - return res; - } - - json to_json_oaicompat_chat_stream() { - std::time_t t = std::time(0); - std::string finish_reason = "length"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls"; - } - - json deltas = json::array(); - for (const auto & diff : oaicompat_msg_diffs) { - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - } - - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - - if (include_usage) { - // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage - // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices - deltas.push_back({ - {"choices", json::array()}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}, - }}, - }); - } - - if (timings.prompt_n >= 0) { - deltas.back().push_back({"timings", timings.to_json()}); - } - - // extra fields for debugging purposes - if (verbose && !deltas.empty()) { - deltas.front()["__verbose"] = to_json_non_oaicompat(); - } - - return deltas; - } -}; - -struct server_task_result_cmpl_partial : server_task_result { - int index = 0; - - std::string content; - llama_tokens tokens; - - int32_t n_decoded; - int32_t n_prompt_tokens; - - bool post_sampling_probs; - bool is_progress = false; - completion_token_output prob_output; - result_timings timings; - result_prompt_progress progress; - - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - std::vector oaicompat_msg_diffs; - - virtual int get_index() override { - return index; - } - - virtual bool is_stop() override { - return false; // in stream mode, partial responses are not considered stop - } - - virtual json to_json() override { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: - return to_json_non_oaicompat(); - case OAICOMPAT_TYPE_COMPLETION: - return to_json_oaicompat(); - case OAICOMPAT_TYPE_CHAT: - return to_json_oaicompat_chat(); - default: - GGML_ASSERT(false && "Invalid oaicompat_type"); - } - } - - json to_json_non_oaicompat() { - // non-OAI-compat JSON - json res = json { - {"index", index}, - {"content", content}, - {"tokens", tokens}, - {"stop", false}, - {"id_slot", id_slot}, - {"tokens_predicted", n_decoded}, - {"tokens_evaluated", n_prompt_tokens}, - }; - // populate the timings object when needed (usually for the last response or with timings_per_token enabled) - if (timings.prompt_n > 0) { - res.push_back({"timings", timings.to_json()}); - } - if (is_progress) { - res.push_back({"prompt_progress", progress.to_json()}); - } - if (!prob_output.probs.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); - } - return res; - } - - json to_json_oaicompat() { - std::time_t t = std::time(0); - json logprobs = json(nullptr); // OAI default to null - if (prob_output.probs.size() > 0) { - logprobs = json{ - {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, - }; - } - json res = json { - {"choices", json::array({ - json{ - {"text", content}, - {"index", index}, - {"logprobs", logprobs}, - {"finish_reason", nullptr}, - } - })}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "text_completion"}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - if (is_progress) { - res.push_back({"prompt_progress", progress.to_json()}); - } - - return res; - } - - json to_json_oaicompat_chat() { - bool first = n_decoded == 1; - std::time_t t = std::time(0); - json choices; - - std::vector deltas; - auto add_delta = [&](const json & delta) { - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", delta}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - }; - // We have to send an initial update to conform to openai behavior - if (first || is_progress) { - add_delta({ - {"role", "assistant"}, - {"content", nullptr}, - }); - } - - for (const auto & diff : oaicompat_msg_diffs) { - add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); - } - - if (!deltas.empty()) { - auto & last_json = deltas[deltas.size() - 1]; - GGML_ASSERT(last_json.at("choices").size() >= 1); - - if (prob_output.probs.size() > 0) { - last_json.at("choices").at(0)["logprobs"] = json { - {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, - }; - } - - if (timings.prompt_n >= 0) { - last_json.push_back({"timings", timings.to_json()}); - } - if (is_progress) { - last_json.push_back({"prompt_progress", progress.to_json()}); - } - } - - return deltas; - } -}; - -struct server_task_result_embd : server_task_result { - int index = 0; - std::vector> embedding; - - int32_t n_tokens; - - // OAI-compat fields - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - - virtual int get_index() override { - return index; - } - - virtual json to_json() override { - return oaicompat == OAICOMPAT_TYPE_EMBEDDING - ? to_json_oaicompat() - : to_json_non_oaicompat(); - } - - json to_json_non_oaicompat() { - return json { - {"index", index}, - {"embedding", embedding}, - }; - } - - json to_json_oaicompat() { - return json { - {"index", index}, - {"embedding", embedding[0]}, - {"tokens_evaluated", n_tokens}, - }; - } -}; - -struct server_task_result_rerank : server_task_result { - int index = 0; - float score = -1e6; - - int32_t n_tokens; - - virtual int get_index() override { - return index; - } - - virtual json to_json() override { - return json { - {"index", index}, - {"score", score}, - {"tokens_evaluated", n_tokens}, - }; - } -}; - -// this function maybe used outside of server_task_result_error -static json format_error_response(const std::string & message, const enum error_type type) { - std::string type_str; - int code = 500; - switch (type) { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - code = 400; - break; - case ERROR_TYPE_AUTHENTICATION: - type_str = "authentication_error"; - code = 401; - break; - case ERROR_TYPE_NOT_FOUND: - type_str = "not_found_error"; - code = 404; - break; - case ERROR_TYPE_SERVER: - type_str = "server_error"; - code = 500; - break; - case ERROR_TYPE_PERMISSION: - type_str = "permission_error"; - code = 403; - break; - case ERROR_TYPE_NOT_SUPPORTED: - type_str = "not_supported_error"; - code = 501; - break; - case ERROR_TYPE_UNAVAILABLE: - type_str = "unavailable_error"; - code = 503; - break; - case ERROR_TYPE_EXCEED_CONTEXT_SIZE: - type_str = "exceed_context_size_error"; - code = 400; - break; - } - return json { - {"code", code}, - {"message", message}, - {"type", type_str}, - }; -} - -struct server_task_result_error : server_task_result { - int index = 0; - error_type err_type = ERROR_TYPE_SERVER; - std::string err_msg; - - // for ERROR_TYPE_EXCEED_CONTEXT_SIZE - int32_t n_prompt_tokens = 0; - int32_t n_ctx = 0; - - virtual bool is_error() override { - return true; - } - - virtual json to_json() override { - json res = format_error_response(err_msg, err_type); - if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { - res["n_prompt_tokens"] = n_prompt_tokens; - res["n_ctx"] = n_ctx; - } - return res; - } -}; - -struct server_task_result_metrics : server_task_result { - int n_idle_slots; - int n_processing_slots; - int n_tasks_deferred; - int64_t t_start; - - // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; - - uint64_t n_tokens_max = 0; - - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - - uint64_t n_decode_total = 0; - uint64_t n_busy_slots_total = 0; - - // while we can also use std::vector this requires copying the slot object which can be quite messy - // therefore, we use json to temporarily store the slot.to_json() result - json slots_data = json::array(); - - virtual json to_json() override { - return json { - { "idle", n_idle_slots }, - { "processing", n_processing_slots }, - { "deferred", n_tasks_deferred }, - { "t_start", t_start }, - - { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, - { "t_tokens_generation_total", t_tokens_generation_total }, - { "n_tokens_predicted_total", n_tokens_predicted_total }, - { "t_prompt_processing_total", t_prompt_processing_total }, - - { "n_tokens_max", n_tokens_max }, - - { "n_prompt_tokens_processed", n_prompt_tokens_processed }, - { "t_prompt_processing", t_prompt_processing }, - { "n_tokens_predicted", n_tokens_predicted }, - { "t_tokens_generation", t_tokens_generation }, - - { "n_decode_total", n_decode_total }, - { "n_busy_slots_total", n_busy_slots_total }, - - { "slots", slots_data }, - }; - } -}; - -struct server_task_result_slot_save_load : server_task_result { - std::string filename; - bool is_save; // true = save, false = load - - size_t n_tokens; - size_t n_bytes; - double t_ms; - - virtual json to_json() override { - if (is_save) { - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_saved", n_tokens }, - { "n_written", n_bytes }, - { "timings", { - { "save_ms", t_ms } - }}, - }; - } - - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_restored", n_tokens }, - { "n_read", n_bytes }, - { "timings", { - { "restore_ms", t_ms } - }}, - }; - } -}; - -struct server_task_result_slot_erase : server_task_result { - size_t n_erased; - - virtual json to_json() override { - return json { - { "id_slot", id_slot }, - { "n_erased", n_erased }, - }; - } -}; - -struct server_task_result_apply_lora : server_task_result { - virtual json to_json() override { - return json {{ "success", true }}; - } -}; - -struct server_prompt_checkpoint { - llama_pos pos_min; - llama_pos pos_max; - - std::vector data; - - size_t size() const { - return data.size(); - } +// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 +enum slot_state { + SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future + SLOT_STATE_PROCESSING_PROMPT, + SLOT_STATE_DONE_PROMPT, + SLOT_STATE_GENERATING, }; -struct server_prompt { - server_tokens tokens; - - std::vector data; - - std::list checkpoints; - - size_t size() const { - size_t res = data.size(); - - for (const auto & checkpoint : checkpoints) { - res += checkpoint.size(); - } - - return res; - } - - int n_tokens() const { - return tokens.size(); - } +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded }; -struct server_prompt_cache { - server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { - this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib); - this->limit_tokens = limit_tokens; - } - - std::list states; - - // in bytes, 0 = no limit - size_t limit_size = 0; - - // in tokens, 0 = no limit - size_t limit_tokens = 0; - - size_t size() const { - size_t res = 0; - - for (const auto & state : states) { - res += state.size(); - } - - return res; - } - - size_t n_tokens() const { - size_t res = 0; - - for (const auto & state : states) { - res += state.n_tokens(); - } - - return res; - } - - server_prompt * alloc(const server_prompt & prompt, size_t state_size) { - // first check if the current state is contained fully in the cache - for (auto it = states.begin(); it != states.end(); ++it) { - const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); - - if (cur_lcp_len == (int) prompt.tokens.size()) { - SRV_WRN("%s", " - prompt is already in the cache, skipping\n"); - return nullptr; - } - } - - // next, remove any cached prompts that are fully contained in the current prompt - for (auto it = states.begin(); it != states.end();) { - const int len = it->tokens.get_common_prefix(prompt.tokens); - - if (len == (int) it->tokens.size()) { - SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); - - it = states.erase(it); - } else { - ++it; - } - } - - std::vector state_data; - - // check if we can allocate enough memory for the new state - try { - state_data.resize(state_size); - } catch (const std::bad_alloc & e) { - SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); - - limit_size = std::max(1, 0.4*size()); - - SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); - - update(); - - return nullptr; - } - - // TODO: for some reason we can't copy server_tokens, so we have to do this workaround - auto & cur = states.emplace_back(); - cur = { - /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), - /*.data =*/ std::move(state_data), - /*.checkpoints =*/ prompt.checkpoints, - }; - - return &cur; - } - - bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { - const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); - - float f_keep_best = float(lcp_best) / prompt.tokens.size(); - float sim_best = float(lcp_best) / tokens_new.size(); - - SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); - - auto it_best = states.end(); - - // find the most similar cached prompt, that would also preserve the most context - for (auto it = states.begin(); it != states.end(); ++it) { - const int lcp_cur = it->tokens.get_common_prefix(tokens_new); - - const float f_keep_cur = float(lcp_cur) / it->tokens.size(); - const float sim_cur = float(lcp_cur) / tokens_new.size(); - - // don't trash large prompts - if (f_keep_cur < 0.25f) { - continue; - } - - if (f_keep_best < f_keep_cur && sim_best < sim_cur) { - f_keep_best = f_keep_cur; - sim_best = sim_cur; - - it_best = it; - } - } - - if (it_best != states.end()) { - SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); - - const size_t size = it_best->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); - if (n != size) { - SRV_WRN("failed to restore state with size %zu\n", size); - - return false; - } - - it_best->data.clear(); - it_best->data.shrink_to_fit(); - - prompt = std::move(*it_best); - - states.erase(it_best); - } - - return true; +static bool server_task_type_need_embd(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + return true; + default: + return false; } +} - void update() { - if (limit_size > 0) { - // always keep at least one state, regardless of the limits - while (states.size() > 1 && size() > limit_size) { - if (states.empty()) { - break; - } - - SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); - - states.pop_front(); - } - } - - // average size per token - const float size_per_token = std::max(1.0f, float(size()) / (std::max(1, n_tokens()))); - - // dynamically increase the token limit if it can fit in the memory limit - const size_t limit_tokens_cur = limit_size > 0 ? std::max(limit_tokens, limit_size/size_per_token) : limit_tokens; - - if (limit_tokens > 0) { - while (states.size() > 1 && n_tokens() > limit_tokens_cur) { - if (states.empty()) { - break; - } - - SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n", - limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0)); - - states.pop_front(); - } - } - - SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n", - states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur); - - for (const auto & state : states) { - SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", - (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); - } +static bool server_task_type_need_logits(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + return true; + default: + return false; } -}; +} struct server_slot { int id; @@ -2022,303 +468,6 @@ struct server_metrics { } }; -struct server_queue { - int id = 0; - bool running; - - // queues - std::deque queue_tasks; - std::deque queue_tasks_deferred; - - std::mutex mutex_tasks; - std::condition_variable condition_tasks; - - // callback functions - std::function callback_new_task; - std::function callback_update_slots; - - // Add a new task to the end of the queue - int post(server_task && task, bool front = false) { - std::unique_lock lock(mutex_tasks); - GGML_ASSERT(task.id != -1); - // if this is cancel task make sure to clean up pending tasks - if (task.type == SERVER_TASK_TYPE_CANCEL) { - cleanup_pending_task(task.id_target); - } - const int task_id = task.id; - QUE_DBG("new task, id = %d, front = %d\n", task_id, front); - if (front) { - queue_tasks.push_front(std::move(task)); - } else { - queue_tasks.push_back(std::move(task)); - } - condition_tasks.notify_one(); - return task_id; - } - - // multi-task version of post() - int post(std::vector && tasks, bool front = false) { - std::unique_lock lock(mutex_tasks); - for (auto & task : tasks) { - if (task.id == -1) { - task.id = id++; - } - // if this is cancel task make sure to clean up pending tasks - if (task.type == SERVER_TASK_TYPE_CANCEL) { - cleanup_pending_task(task.id_target); - } - QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); - if (front) { - queue_tasks.push_front(std::move(task)); - } else { - queue_tasks.push_back(std::move(task)); - } - } - condition_tasks.notify_one(); - return 0; - } - - // Add a new task, but defer until one slot is available - void defer(server_task && task) { - std::unique_lock lock(mutex_tasks); - QUE_DBG("defer task, id = %d\n", task.id); - queue_tasks_deferred.push_back(std::move(task)); - condition_tasks.notify_one(); - } - - // Get the next id for creating a new task - int get_new_id() { - std::unique_lock lock(mutex_tasks); - int new_id = id++; - return new_id; - } - - // Register function to process a new task - void on_new_task(std::function callback) { - callback_new_task = std::move(callback); - } - - // Register the function to be called when all slots data is ready to be processed - void on_update_slots(std::function callback) { - callback_update_slots = std::move(callback); - } - - // Call when the state of one slot is changed, it will move one task from deferred to main queue - void pop_deferred_task() { - std::unique_lock lock(mutex_tasks); - if (!queue_tasks_deferred.empty()) { - queue_tasks.emplace_front(std::move(queue_tasks_deferred.front())); - queue_tasks_deferred.pop_front(); - } - condition_tasks.notify_one(); - } - - // end the start_loop routine - void terminate() { - std::unique_lock lock(mutex_tasks); - running = false; - condition_tasks.notify_all(); - } - - /** - * Main loop consists of these steps: - * - Wait until a new task arrives - * - Process the task (i.e. maybe copy data into slot) - * - Check if multitask is finished - * - Update all slots - */ - void start_loop() { - running = true; - - while (true) { - QUE_DBG("%s", "processing new tasks\n"); - - while (true) { - std::unique_lock lock(mutex_tasks); - if (!running) { - QUE_DBG("%s", "terminate\n"); - return; - } - if (queue_tasks.empty()) { - lock.unlock(); - break; - } - server_task task = std::move(queue_tasks.front()); - queue_tasks.pop_front(); - lock.unlock(); - - QUE_DBG("processing task, id = %d\n", task.id); - callback_new_task(std::move(task)); - } - - // all tasks in the current loop is processed, slots data is now ready - QUE_DBG("%s", "update slots\n"); - - callback_update_slots(); - - QUE_DBG("%s", "waiting for new tasks\n"); - { - std::unique_lock lock(mutex_tasks); - if (!running) { - QUE_DBG("%s", "terminate\n"); - return; - } - if (queue_tasks.empty()) { - condition_tasks.wait(lock, [&]{ - return (!queue_tasks.empty() || !running); - }); - } - } - } - } - -private: - void cleanup_pending_task(int id_target) { - // no need lock because this is called exclusively by post() - auto rm_func = [id_target](const server_task & task) { - return task.id == id_target; - }; - queue_tasks.erase( - std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), - queue_tasks.end()); - queue_tasks_deferred.erase( - std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), - queue_tasks_deferred.end()); - } -}; - -struct server_response { - bool running = true; - - // for keeping track of all tasks waiting for the result - std::unordered_set waiting_task_ids; - - // the main result queue (using ptr for polymorphism) - std::vector queue_results; - - std::mutex mutex_results; - std::condition_variable condition_results; - - // add the id_task to the list of tasks waiting for response - void add_waiting_task_id(int id_task) { - SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); - - std::unique_lock lock(mutex_results); - waiting_task_ids.insert(id_task); - } - - void add_waiting_tasks(const std::vector & tasks) { - std::unique_lock lock(mutex_results); - - for (const auto & task : tasks) { - SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); - waiting_task_ids.insert(task.id); - } - } - - // when the request is finished, we can remove task associated with it - void remove_waiting_task_id(int id_task) { - SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); - - std::unique_lock lock(mutex_results); - waiting_task_ids.erase(id_task); - // make sure to clean up all pending results - queue_results.erase( - std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { - return res->id == id_task; - }), - queue_results.end()); - } - - void remove_waiting_task_ids(const std::unordered_set & id_tasks) { - std::unique_lock lock(mutex_results); - - for (const auto & id_task : id_tasks) { - SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); - waiting_task_ids.erase(id_task); - } - } - - // This function blocks the thread until there is a response for one of the id_tasks - server_task_result_ptr recv(const std::unordered_set & id_tasks) { - while (true) { - std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&]{ - if (!running) { - SRV_DBG("%s : queue result stop\n", __func__); - std::terminate(); // we cannot return here since the caller is HTTP code - } - return !queue_results.empty(); - }); - - for (size_t i = 0; i < queue_results.size(); i++) { - if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { - server_task_result_ptr res = std::move(queue_results[i]); - queue_results.erase(queue_results.begin() + i); - return res; - } - } - } - - // should never reach here - } - - // same as recv(), but have timeout in seconds - // if timeout is reached, nullptr is returned - server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { - while (true) { - std::unique_lock lock(mutex_results); - - for (int i = 0; i < (int) queue_results.size(); i++) { - if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { - server_task_result_ptr res = std::move(queue_results[i]); - queue_results.erase(queue_results.begin() + i); - return res; - } - } - - std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); - if (!running) { - SRV_DBG("%s : queue result stop\n", __func__); - std::terminate(); // we cannot return here since the caller is HTTP code - } - if (cr_res == std::cv_status::timeout) { - return nullptr; - } - } - - // should never reach here - } - - // single-task version of recv() - server_task_result_ptr recv(int id_task) { - std::unordered_set id_tasks = {id_task}; - return recv(id_tasks); - } - - // Send a new result to a waiting id_task - void send(server_task_result_ptr && result) { - SRV_DBG("sending result for task id = %d\n", result->id); - - std::unique_lock lock(mutex_results); - for (const auto & id_task : waiting_task_ids) { - if (result->id == id_task) { - SRV_DBG("task id = %d pushed to result queue\n", result->id); - - queue_results.emplace_back(std::move(result)); - condition_results.notify_all(); - return; - } - } - } - - // terminate the waiting loop - void terminate() { - running = false; - condition_results.notify_all(); - } -}; - struct server_context { common_params params_base; @@ -3323,7 +1472,7 @@ struct server_context { res->slots_data = std::move(slots_data); res->n_idle_slots = n_idle_slots; res->n_processing_slots = n_processing_slots; - res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); + res->n_tasks_deferred = queue_tasks.queue_tasks_deferred_size(); res->t_start = metrics.t_start; res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; @@ -4645,7 +2794,7 @@ struct server_routes { json default_generation_settings_for_props; { - slot_params params; + task_params params; params.sampling = ctx_server.params_base.sampling; @@ -4784,7 +2933,7 @@ struct server_routes { std::string prompt = json_value(data, "prompt", std::string()); std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true); SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); - data["prompt"] = format_infill( + data["prompt"] = format_prompt_infill( ctx_server.vocab, data.at("input_prefix"), data.at("input_suffix"), @@ -4842,6 +2991,38 @@ struct server_routes { OAICOMPAT_TYPE_CHAT); }; + server_http_context::handler_t post_anthropic_messages = [this](const server_http_req & req) { + std::vector files; + json body = json::parse(req.body); + json body_parsed = anthropic_params_from_json( + body, + ctx_server.oai_parser_opt, + files); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + req.should_stop, + OAICOMPAT_TYPE_ANTHROPIC); + }; + + server_http_context::handler_t post_anthropic_count_tokens = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + std::vector files; + json body = json::parse(req.body); + + json body_parsed = anthropic_params_from_json( + body, + ctx_server.oai_parser_opt, + files); + + json prompt = body_parsed.at("prompt"); + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true); + + res->ok({{"input_tokens", static_cast(tokens.size())}}); + return res; + }; + // same with handle_chat_completions, but without inference part server_http_context::handler_t post_apply_template = [this](const server_http_req & req) { auto res = std::make_unique(ctx_server); @@ -4939,8 +3120,7 @@ struct server_routes { } } - const json data = format_tokenizer_response(tokens_response); - res->ok(data); + res->ok(json{{"tokens", std::move(tokens_response)}}); return res; }; @@ -4951,11 +3131,10 @@ struct server_routes { std::string content; if (body.count("tokens") != 0) { const llama_tokens tokens = body.at("tokens"); - content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); + content = tokens_to_str(ctx_server.ctx, tokens); } - const json data = format_detokenized_response(content); - res->ok(data); + res->ok(json{{"content", std::move(content)}}); return res; }; @@ -5009,7 +3188,7 @@ struct server_routes { std::vector tasks; tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { - auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); + auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; @@ -5205,7 +3384,11 @@ struct server_routes { } // next responses are streamed - res->data = format_sse(first_result->to_json()); // to be sent immediately + if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { + res->data = format_anthropic_sse(first_result->to_json()); + } else { + res->data = format_sse(first_result->to_json()); // to be sent immediately + } res->status = 200; res->content_type = "text/event-stream"; res->next = [res_this = res.get(), oaicompat, &should_stop](std::string & output) -> bool { @@ -5225,7 +3408,10 @@ struct server_routes { // check if there is more data if (!rd.has_next()) { - if (oaicompat != OAICOMPAT_TYPE_NONE) { + if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { + // Anthropic doesn't send [DONE], message_stop was already sent + output = ""; + } else if (oaicompat != OAICOMPAT_TYPE_NONE) { output = "data: [DONE]\n\n"; } else { output = ""; @@ -5244,7 +3430,14 @@ struct server_routes { // send the results json res_json = result->to_json(); if (result->is_error()) { - output = format_sse(json {{ "error", res_json }}); + if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { + output = format_anthropic_sse({ + {"event", "error"}, + {"data", res_json}, + }); + } else { + output = format_sse(json {{ "error", res_json }}); + } SRV_DBG("%s", "error received during streaming, terminating stream\n"); return false; // terminate on error } else { @@ -5252,7 +3445,11 @@ struct server_routes { dynamic_cast(result.get()) != nullptr || dynamic_cast(result.get()) != nullptr ); - output = format_sse(res_json); + if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) { + output = format_anthropic_sse(res_json); + } else { + output = format_sse(res_json); + } } // has next data, continue @@ -5460,10 +3657,10 @@ struct server_routes { } }; -std::function shutdown_handler; -std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; +static std::function shutdown_handler; +static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; -inline void signal_handler(int signal) { +static inline void signal_handler(int signal) { if (is_terminating.test_and_set()) { // in case it hangs, we can force terminate the server by hitting Ctrl+C twice // this is for better developer experience, we can remove when the server is stable enough @@ -5565,6 +3762,8 @@ int main(int argc, char ** argv) { ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions)); ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions)); ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint + ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API + ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting ctx_http.post("/infill", ex_wrapper(routes.post_infill)); ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings)); @@ -5632,6 +3831,7 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.terminate(); }; + // TODO: refactor in common/console #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; sigint_action.sa_handler = signal_handler; diff --git a/tools/server/tests/unit/test_anthropic_api.py b/tools/server/tests/unit/test_anthropic_api.py new file mode 100644 index 00000000000..d55dd1d9454 --- /dev/null +++ b/tools/server/tests/unit/test_anthropic_api.py @@ -0,0 +1,807 @@ +#!/usr/bin/env python3 +import pytest +import base64 +import requests + +from utils import * + +server: ServerProcess + + +def get_test_image_base64() -> str: + """Get a test image in base64 format""" + # Use the same test image as test_vision_api.py + IMG_URL = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png" + response = requests.get(IMG_URL) + response.raise_for_status() + return base64.b64encode(response.content).decode("utf-8") + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.model_alias = "tinyllama-2-anthropic" + server.server_port = 8082 + server.n_slots = 1 + server.n_ctx = 8192 + server.n_batch = 2048 + + +@pytest.fixture +def vision_server(): + """Separate fixture for vision tests that require multimodal support""" + global server + server = ServerPreset.tinygemma3() + server.offline = False # Allow downloading the model + server.model_alias = "tinygemma3-anthropic" + server.server_port = 8083 # Different port to avoid conflicts + server.n_slots = 1 + return server + + +# Basic message tests + +def test_anthropic_messages_basic(): + """Test basic Anthropic messages endpoint""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Say hello"} + ] + }) + + assert res.status_code == 200, f"Expected 200, got {res.status_code}" + assert res.body["type"] == "message", f"Expected type 'message', got {res.body.get('type')}" + assert res.body["role"] == "assistant", f"Expected role 'assistant', got {res.body.get('role')}" + assert "content" in res.body, "Missing 'content' field" + assert isinstance(res.body["content"], list), "Content should be an array" + assert len(res.body["content"]) > 0, "Content array should not be empty" + assert res.body["content"][0]["type"] == "text", "First content block should be text" + assert "text" in res.body["content"][0], "Text content block missing 'text' field" + assert res.body["stop_reason"] in ["end_turn", "max_tokens"], f"Invalid stop_reason: {res.body.get('stop_reason')}" + assert "usage" in res.body, "Missing 'usage' field" + assert "input_tokens" in res.body["usage"], "Missing usage.input_tokens" + assert "output_tokens" in res.body["usage"], "Missing usage.output_tokens" + assert isinstance(res.body["usage"]["input_tokens"], int), "input_tokens should be integer" + assert isinstance(res.body["usage"]["output_tokens"], int), "output_tokens should be integer" + assert res.body["usage"]["output_tokens"] > 0, "Should have generated some tokens" + # Anthropic API should NOT include timings + assert "timings" not in res.body, "Anthropic API should not include timings field" + + +def test_anthropic_messages_with_system(): + """Test messages with system prompt""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + + +def test_anthropic_messages_multipart_content(): + """Test messages with multipart content blocks""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is"}, + {"type": "text", "text": " the answer?"} + ] + } + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_messages_conversation(): + """Test multi-turn conversation""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Streaming tests + +def test_anthropic_messages_streaming(): + """Test streaming messages""" + server.start() + + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 30, + "messages": [ + {"role": "user", "content": "Say hello"} + ], + "stream": True + }) + + events = [] + for data in res: + # Each event should have type and other fields + assert "type" in data, f"Missing 'type' in event: {data}" + events.append(data) + + # Verify event sequence + event_types = [e["type"] for e in events] + assert "message_start" in event_types, "Missing message_start event" + assert "content_block_start" in event_types, "Missing content_block_start event" + assert "content_block_delta" in event_types, "Missing content_block_delta event" + assert "content_block_stop" in event_types, "Missing content_block_stop event" + assert "message_delta" in event_types, "Missing message_delta event" + assert "message_stop" in event_types, "Missing message_stop event" + + # Check message_start structure + message_start = next(e for e in events if e["type"] == "message_start") + assert "message" in message_start, "message_start missing 'message' field" + assert message_start["message"]["type"] == "message" + assert message_start["message"]["role"] == "assistant" + assert message_start["message"]["content"] == [] + assert "usage" in message_start["message"] + assert message_start["message"]["usage"]["input_tokens"] > 0 + + # Check content_block_start + block_start = next(e for e in events if e["type"] == "content_block_start") + assert "index" in block_start, "content_block_start missing 'index'" + assert block_start["index"] == 0, "First content block should be at index 0" + assert "content_block" in block_start + assert block_start["content_block"]["type"] == "text" + + # Check content_block_delta + deltas = [e for e in events if e["type"] == "content_block_delta"] + assert len(deltas) > 0, "Should have at least one content_block_delta" + for delta in deltas: + assert "index" in delta + assert "delta" in delta + assert delta["delta"]["type"] == "text_delta" + assert "text" in delta["delta"] + + # Check content_block_stop + block_stop = next(e for e in events if e["type"] == "content_block_stop") + assert "index" in block_stop + assert block_stop["index"] == 0 + + # Check message_delta + message_delta = next(e for e in events if e["type"] == "message_delta") + assert "delta" in message_delta + assert "stop_reason" in message_delta["delta"] + assert message_delta["delta"]["stop_reason"] in ["end_turn", "max_tokens"] + assert "usage" in message_delta + assert message_delta["usage"]["output_tokens"] > 0 + + # Check message_stop + message_stop = next(e for e in events if e["type"] == "message_stop") + # message_stop should NOT have timings for Anthropic API + assert "timings" not in message_stop, "Anthropic streaming should not include timings" + + +# Token counting tests + +def test_anthropic_count_tokens(): + """Test token counting endpoint""" + server.start() + + res = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "messages": [ + {"role": "user", "content": "Hello world"} + ] + }) + + assert res.status_code == 200 + assert "input_tokens" in res.body + assert isinstance(res.body["input_tokens"], int) + assert res.body["input_tokens"] > 0 + # Should only have input_tokens, no other fields + assert "output_tokens" not in res.body + + +def test_anthropic_count_tokens_with_system(): + """Test token counting with system prompt""" + server.start() + + res = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["input_tokens"] > 0 + + +def test_anthropic_count_tokens_no_max_tokens(): + """Test that count_tokens doesn't require max_tokens""" + server.start() + + # max_tokens is NOT required for count_tokens + res = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert "input_tokens" in res.body + + +# Tool use tests + +def test_anthropic_tool_use_basic(): + """Test basic tool use""" + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 200, + "tools": [{ + "name": "get_weather", + "description": "Get the current weather in a location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name" + } + }, + "required": ["location"] + } + }], + "messages": [ + {"role": "user", "content": "What's the weather in Paris?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + + # Check if model used the tool (it might not always, depending on the model) + content_types = [block.get("type") for block in res.body["content"]] + + if "tool_use" in content_types: + # Model used the tool + assert res.body["stop_reason"] == "tool_use" + + # Find the tool_use block + tool_block = next(b for b in res.body["content"] if b.get("type") == "tool_use") + assert "id" in tool_block + assert "name" in tool_block + assert tool_block["name"] == "get_weather" + assert "input" in tool_block + assert isinstance(tool_block["input"], dict) + + +def test_anthropic_tool_result(): + """Test sending tool results back + + This test verifies that tool_result blocks are properly converted to + role="tool" messages internally. Without proper conversion, this would + fail with a 500 error: "unsupported content[].type" because tool_result + blocks would remain in the user message content array. + """ + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "test123", + "name": "get_weather", + "input": {"location": "Paris"} + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "test123", + "content": "The weather is sunny, 25°C" + } + ] + } + ] + }) + + # This would be 500 with the old bug where tool_result blocks weren't converted + assert res.status_code == 200 + assert res.body["type"] == "message" + # Model should respond to the tool result + assert len(res.body["content"]) > 0 + assert res.body["content"][0]["type"] == "text" + + +def test_anthropic_tool_result_with_text(): + """Test tool result mixed with text content + + This tests the edge case where a user message contains both text and + tool_result blocks. The server must properly split these into separate + messages: a user message with text, followed by tool messages. + Without proper handling, this would fail with 500: "unsupported content[].type" + """ + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "tool_1", + "name": "get_weather", + "input": {"location": "Paris"} + } + ] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Here are the results:"}, + { + "type": "tool_result", + "tool_use_id": "tool_1", + "content": "Sunny, 25°C" + } + ] + } + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + + +def test_anthropic_tool_result_error(): + """Test tool result with error flag""" + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "messages": [ + {"role": "user", "content": "Get the weather"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "test123", + "name": "get_weather", + "input": {"location": "InvalidCity"} + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "test123", + "is_error": True, + "content": "City not found" + } + ] + } + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_tool_streaming(): + """Test streaming with tool use""" + server.jinja = True + server.start() + + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 200, + "stream": True, + "tools": [{ + "name": "calculator", + "description": "Calculate math", + "input_schema": { + "type": "object", + "properties": { + "expression": {"type": "string"} + }, + "required": ["expression"] + } + }], + "messages": [ + {"role": "user", "content": "Calculate 2+2"} + ] + }) + + events = [] + for data in res: + events.append(data) + + event_types = [e["type"] for e in events] + + # Should have basic events + assert "message_start" in event_types + assert "message_stop" in event_types + + # If tool was used, check for proper tool streaming + if any(e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "tool_use" + for e in events): + # Find tool use block start + tool_starts = [e for e in events if + e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "tool_use"] + + assert len(tool_starts) > 0, "Should have tool_use content_block_start" + + # Check index is correct (should be 0 if no text, 1 if there's text) + tool_start = tool_starts[0] + assert "index" in tool_start + assert tool_start["content_block"]["type"] == "tool_use" + assert "name" in tool_start["content_block"] + + +# Vision/multimodal tests + +def test_anthropic_vision_format_accepted(): + """Test that Anthropic vision format is accepted (format validation only)""" + server.start() + + # Small 1x1 red PNG image in base64 + red_pixel_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 10, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": red_pixel_png + } + }, + { + "type": "text", + "text": "What is this?" + } + ] + } + ] + }) + + # Server accepts the format but tinyllama doesn't support images + # So it should return 500 with clear error message about missing mmproj + assert res.status_code == 500 + assert "image input is not supported" in res.body.get("error", {}).get("message", "").lower() + + +def test_anthropic_vision_base64_with_multimodal_model(vision_server): + """Test vision with base64 image using Anthropic format with multimodal model""" + global server + server = vision_server + server.start() + + # Get test image in base64 format + image_base64 = get_test_image_base64() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 10, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": image_base64 + } + }, + { + "type": "text", + "text": "What is this:\n" + } + ] + } + ] + }) + + assert res.status_code == 200, f"Expected 200, got {res.status_code}: {res.body}" + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + assert res.body["content"][0]["type"] == "text" + # The model should generate some response about the image + assert len(res.body["content"][0]["text"]) > 0 + + +# Parameter tests + +def test_anthropic_stop_sequences(): + """Test stop_sequences parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "stop_sequences": ["\n", "END"], + "messages": [ + {"role": "user", "content": "Count to 10"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_temperature(): + """Test temperature parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "temperature": 0.5, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_top_p(): + """Test top_p parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "top_p": 0.9, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_top_k(): + """Test top_k parameter (llama.cpp specific)""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "top_k": 40, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Error handling tests + +def test_anthropic_missing_messages(): + """Test error when messages are missing""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50 + # missing "messages" field + }) + + # Should return an error (400 or 500) + assert res.status_code >= 400 + + +def test_anthropic_empty_messages(): + """Test permissive handling of empty messages array""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [] + }) + + # Server is permissive and accepts empty messages (provides defaults) + # This matches the permissive validation design choice + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Content block index tests + +def test_anthropic_streaming_content_block_indices(): + """Test that content block indices are correct in streaming""" + server.jinja = True + server.start() + + # Request that might produce both text and tool use + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 200, + "stream": True, + "tools": [{ + "name": "test_tool", + "description": "A test tool", + "input_schema": { + "type": "object", + "properties": { + "param": {"type": "string"} + }, + "required": ["param"] + } + }], + "messages": [ + {"role": "user", "content": "Use the test tool"} + ] + }) + + events = [] + for data in res: + events.append(data) + + # Check content_block_start events have sequential indices + block_starts = [e for e in events if e.get("type") == "content_block_start"] + if len(block_starts) > 1: + # If there are multiple blocks, indices should be sequential + indices = [e["index"] for e in block_starts] + expected_indices = list(range(len(block_starts))) + assert indices == expected_indices, f"Expected indices {expected_indices}, got {indices}" + + # Check content_block_stop events match the starts + block_stops = [e for e in events if e.get("type") == "content_block_stop"] + start_indices = set(e["index"] for e in block_starts) + stop_indices = set(e["index"] for e in block_stops) + assert start_indices == stop_indices, "content_block_stop indices should match content_block_start indices" + + +# Extended features tests + +def test_anthropic_thinking(): + """Test extended thinking parameter""" + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "thinking": { + "type": "enabled", + "budget_tokens": 50 + }, + "messages": [ + {"role": "user", "content": "What is 2+2?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_metadata(): + """Test metadata parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "metadata": { + "user_id": "test_user_123" + }, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Compatibility tests + +def test_anthropic_vs_openai_different_response_format(): + """Verify Anthropic format is different from OpenAI format""" + server.start() + + # Make OpenAI request + openai_res = server.make_request("POST", "/v1/chat/completions", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + # Make Anthropic request + anthropic_res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert openai_res.status_code == 200 + assert anthropic_res.status_code == 200 + + # OpenAI has "object", Anthropic has "type" + assert "object" in openai_res.body + assert "type" in anthropic_res.body + assert openai_res.body["object"] == "chat.completion" + assert anthropic_res.body["type"] == "message" + + # OpenAI has "choices", Anthropic has "content" + assert "choices" in openai_res.body + assert "content" in anthropic_res.body + + # Different usage field names + assert "prompt_tokens" in openai_res.body["usage"] + assert "input_tokens" in anthropic_res.body["usage"] + assert "completion_tokens" in openai_res.body["usage"] + assert "output_tokens" in anthropic_res.body["usage"] diff --git a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte index e891f7efdc5..c736178fe22 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte @@ -29,6 +29,7 @@ sendMessage, stopGeneration } from '$lib/stores/chat.svelte'; + import { config } from '$lib/stores/settings.svelte'; import { supportsVision, supportsAudio, @@ -47,6 +48,7 @@ let { showCenteredEmpty = false } = $props(); + let disableAutoScroll = $derived(Boolean(config().disableAutoScroll)); let autoScrollEnabled = $state(true); let chatScrollContainer: HTMLDivElement | undefined = $state(); let dragCounter = $state(0); @@ -149,7 +151,7 @@ } function handleScroll() { - if (!chatScrollContainer) return; + if (disableAutoScroll || !chatScrollContainer) return; const { scrollTop, scrollHeight, clientHeight } = chatScrollContainer; const distanceFromBottom = scrollHeight - scrollTop - clientHeight; @@ -194,8 +196,10 @@ const extras = result?.extras; // Enable autoscroll for user-initiated message sending - userScrolledUp = false; - autoScrollEnabled = true; + if (!disableAutoScroll) { + userScrolledUp = false; + autoScrollEnabled = true; + } await sendMessage(message, extras); scrollChatToBottom(); @@ -241,6 +245,8 @@ } function scrollChatToBottom(behavior: ScrollBehavior = 'smooth') { + if (disableAutoScroll) return; + chatScrollContainer?.scrollTo({ top: chatScrollContainer?.scrollHeight, behavior @@ -248,14 +254,27 @@ } afterNavigate(() => { - setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY); + if (!disableAutoScroll) { + setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY); + } }); onMount(() => { - setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY); + if (!disableAutoScroll) { + setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY); + } }); $effect(() => { + if (disableAutoScroll) { + autoScrollEnabled = false; + if (scrollInterval) { + clearInterval(scrollInterval); + scrollInterval = undefined; + } + return; + } + if (isCurrentConversationLoading && autoScrollEnabled) { scrollInterval = setInterval(scrollChatToBottom, AUTO_SCROLL_INTERVAL); } else if (scrollInterval) { @@ -289,9 +308,11 @@ class="mb-16 md:mb-24" messages={activeMessages()} onUserAction={() => { - userScrolledUp = false; - autoScrollEnabled = true; - scrollChatToBottom(); + if (!disableAutoScroll) { + userScrolledUp = false; + autoScrollEnabled = true; + scrollChatToBottom(); + } }} /> diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte index d00ae128538..204f0d7551e 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte @@ -3,7 +3,6 @@ Settings, Funnel, AlertTriangle, - Brain, Code, Monitor, Sun, @@ -58,6 +57,33 @@ label: 'Paste long text to file length', type: 'input' }, + { + key: 'enableContinueGeneration', + label: 'Enable "Continue" button', + type: 'checkbox', + isExperimental: true + }, + { + key: 'pdfAsImage', + label: 'Parse PDF as image', + type: 'checkbox' + }, + { + key: 'askForTitleConfirmation', + label: 'Ask for confirmation before changing conversation title', + type: 'checkbox' + } + ] + }, + { + title: 'Display', + icon: Monitor, + fields: [ + { + key: 'showThoughtInProgress', + label: 'Show thought in progress', + type: 'checkbox' + }, { key: 'showMessageStats', label: 'Show message generation statistics', @@ -79,25 +105,14 @@ type: 'checkbox' }, { - key: 'enableContinueGeneration', - label: 'Enable "Continue" button', - type: 'checkbox', - isExperimental: true - }, - { - key: 'pdfAsImage', - label: 'Parse PDF as image', + key: 'disableAutoScroll', + label: 'Disable automatic scroll', type: 'checkbox' }, { key: 'renderUserContentAsMarkdown', label: 'Render user content as Markdown', type: 'checkbox' - }, - { - key: 'askForTitleConfirmation', - label: 'Ask for confirmation before changing conversation title', - type: 'checkbox' } ] }, @@ -208,17 +223,6 @@ } ] }, - { - title: 'Reasoning', - icon: Brain, - fields: [ - { - key: 'showThoughtInProgress', - label: 'Show thought in progress', - type: 'checkbox' - } - ] - }, { title: 'Import/Export', icon: Database, diff --git a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte index 7e83d30f132..176a98b212f 100644 --- a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte +++ b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte @@ -8,6 +8,7 @@ import rehypeKatex from 'rehype-katex'; import rehypeStringify from 'rehype-stringify'; import { copyCodeToClipboard } from '$lib/utils/copy'; + import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer'; import { preprocessLaTeX } from '$lib/utils/latex-protection'; import { browser } from '$app/environment'; import '$styles/katex-custom.scss'; @@ -60,6 +61,7 @@ .use(remarkRehype) // Convert Markdown AST to rehype .use(rehypeKatex) // Render math using KaTeX .use(rehypeHighlight) // Add syntax highlighting + .use(rehypeRestoreTableHtml) // Restore limited HTML (e.g.,
,
    ) inside Markdown tables .use(rehypeStringify); // Convert to HTML string }); diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts index c25ea23f37b..6783757e6b4 100644 --- a/tools/server/webui/src/lib/constants/settings-config.ts +++ b/tools/server/webui/src/lib/constants/settings-config.ts @@ -14,6 +14,7 @@ export const SETTING_CONFIG_DEFAULT: Record = pasteLongTextToFileLen: 2500, pdfAsImage: false, showModelInfo: false, + disableAutoScroll: false, renderUserContentAsMarkdown: false, modelSelectorEnabled: false, // make sure these default values are in sync with `common.h` @@ -93,6 +94,8 @@ export const SETTING_CONFIG_INFO: Record = { 'Ask for confirmation before automatically changing conversation title when editing the first message.', pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).', showModelInfo: 'Display the model name used to generate each message below the message content.', + disableAutoScroll: + 'Disable automatic scrolling while messages stream so you can control the viewport position manually.', renderUserContentAsMarkdown: 'Render user messages using markdown formatting in the chat.', modelSelectorEnabled: 'Enable the model selector in the chat input to choose the inference model. Sends the associated model field in API requests.', diff --git a/tools/server/webui/src/lib/constants/table-html-restorer.ts b/tools/server/webui/src/lib/constants/table-html-restorer.ts new file mode 100644 index 00000000000..e5d5b120114 --- /dev/null +++ b/tools/server/webui/src/lib/constants/table-html-restorer.ts @@ -0,0 +1,20 @@ +/** + * Matches
    ,
    ,
    tags (case-insensitive). + * Used to detect line breaks in table cell text content. + */ +export const BR_PATTERN = //gi; + +/** + * Matches a complete
      ...
    block. + * Captures the inner content (group 1) for further
  • extraction. + * Case-insensitive, allows multiline content. + */ +export const LIST_PATTERN = /^
      ([\s\S]*)<\/ul>$/i; + +/** + * Matches individual
    • ...
    • elements within a list. + * Captures the inner content (group 1) of each list item. + * Non-greedy to handle multiple consecutive items. + * Case-insensitive, allows multiline content. + */ +export const LI_PATTERN = /
    • ([\s\S]*?)<\/li>/gi; diff --git a/tools/server/webui/src/lib/markdown/table-html-restorer.ts b/tools/server/webui/src/lib/markdown/table-html-restorer.ts new file mode 100644 index 00000000000..918aa468116 --- /dev/null +++ b/tools/server/webui/src/lib/markdown/table-html-restorer.ts @@ -0,0 +1,181 @@ +/** + * Rehype plugin to restore limited HTML elements inside Markdown table cells. + * + * ## Problem + * The remark/rehype pipeline neutralizes inline HTML as literal text + * (remarkLiteralHtml) so that XML/HTML snippets in LLM responses display + * as-is instead of being rendered. This causes
      and
        markup in + * table cells to show as plain text. + * + * ## Solution + * This plugin traverses the HAST post-conversion, parses whitelisted HTML + * patterns from text nodes, and replaces them with actual HAST element nodes + * that will be rendered as real HTML. + * + * ## Supported HTML + * - `
        ` / `
        ` / `
        ` - Line breaks (inline) + * - `
        • ...
        ` - Unordered lists (block) + * + * ## Key Implementation Details + * + * ### 1. Sibling Combination (Critical) + * The Markdown pipeline may fragment content across multiple text nodes and `
        ` + * elements. For example, `
        • a
        ` might arrive as: + * - Text: `"
          "` + * - Element: `
          ` + * - Text: `"
        • a
        "` + * + * We must combine consecutive text nodes and `
        ` elements into a single string + * before attempting to parse list markup. Without this, list detection fails. + * + * ### 2. visitParents for Deep Traversal + * Table cell content may be wrapped in intermediate elements (e.g., `

        ` tags). + * Using `visitParents` instead of direct child iteration ensures we find text + * nodes at any depth within the cell. + * + * ### 3. Reference Comparison for No-Op Detection + * When checking if `
        ` expansion changed anything, we compare: + * `expanded.length !== 1 || expanded[0] !== textNode` + * + * This catches both cases: + * - Multiple nodes created (text was split) + * - Single NEW node created (original had only `
        `, now it's an element) + * + * A simple `length > 1` check would miss the single `
        ` case. + * + * ### 4. Strict List Validation + * `parseList()` rejects malformed markup by checking for garbage text between + * `

      • ` elements. This prevents creating broken DOM from partial matches like + * `
          garbage
        • a
        `. + * + * ### 5. Newline Substitution for `
        ` in Combined String + * When combining siblings, existing `
        ` elements become `\n` in the combined + * string. This allows list content to span visual lines while still being parsed + * as a single unit. + * + * @example + * // Input Markdown: + * // | Feature | Notes | + * // |---------|-------| + * // | Multi-line | First
        Second | + * // | List |
        • A
        • B
        | + * // + * // Without this plugin:
        and
          render as literal text + * // With this plugin:
          becomes line break,
            becomes actual list + */ + +import type { Plugin } from 'unified'; +import type { Element, ElementContent, Root, Text } from 'hast'; +import { visit } from 'unist-util-visit'; +import { visitParents } from 'unist-util-visit-parents'; +import { BR_PATTERN, LIST_PATTERN, LI_PATTERN } from '$lib/constants/table-html-restorer'; + +/** + * Expands text containing `
            ` tags into an array of text nodes and br elements. + */ +function expandBrTags(value: string): ElementContent[] { + const matches = [...value.matchAll(BR_PATTERN)]; + if (!matches.length) return [{ type: 'text', value } as Text]; + + const result: ElementContent[] = []; + let cursor = 0; + + for (const m of matches) { + if (m.index! > cursor) { + result.push({ type: 'text', value: value.slice(cursor, m.index) } as Text); + } + result.push({ type: 'element', tagName: 'br', properties: {}, children: [] } as Element); + cursor = m.index! + m[0].length; + } + + if (cursor < value.length) { + result.push({ type: 'text', value: value.slice(cursor) } as Text); + } + + return result; +} + +/** + * Parses a `
            • ...
            ` string into a HAST element. + * Returns null if the markup is malformed or contains unexpected content. + */ +function parseList(value: string): Element | null { + const match = value.trim().match(LIST_PATTERN); + if (!match) return null; + + const body = match[1]; + const items: ElementContent[] = []; + let cursor = 0; + + for (const liMatch of body.matchAll(LI_PATTERN)) { + // Reject if there's non-whitespace between list items + if (body.slice(cursor, liMatch.index!).trim()) return null; + + items.push({ + type: 'element', + tagName: 'li', + properties: {}, + children: expandBrTags(liMatch[1] ?? '') + } as Element); + + cursor = liMatch.index! + liMatch[0].length; + } + + // Reject if no items found or trailing garbage exists + if (!items.length || body.slice(cursor).trim()) return null; + + return { type: 'element', tagName: 'ul', properties: {}, children: items } as Element; +} + +/** + * Processes a single table cell, restoring HTML elements from text content. + */ +function processCell(cell: Element) { + visitParents(cell, 'text', (textNode: Text, ancestors) => { + const parent = ancestors[ancestors.length - 1]; + if (!parent || parent.type !== 'element') return; + + const parentEl = parent as Element; + const siblings = parentEl.children as ElementContent[]; + const startIndex = siblings.indexOf(textNode as ElementContent); + if (startIndex === -1) return; + + // Combine consecutive text nodes and
            elements into one string + let combined = ''; + let endIndex = startIndex; + + for (let i = startIndex; i < siblings.length; i++) { + const sib = siblings[i]; + if (sib.type === 'text') { + combined += (sib as Text).value; + endIndex = i; + } else if (sib.type === 'element' && (sib as Element).tagName === 'br') { + combined += '\n'; + endIndex = i; + } else { + break; + } + } + + // Try parsing as list first (replaces entire combined range) + const list = parseList(combined); + if (list) { + siblings.splice(startIndex, endIndex - startIndex + 1, list); + return; + } + + // Otherwise, just expand
            tags in this text node + const expanded = expandBrTags(textNode.value); + if (expanded.length !== 1 || expanded[0] !== textNode) { + siblings.splice(startIndex, 1, ...expanded); + } + }); +} + +export const rehypeRestoreTableHtml: Plugin<[], Root> = () => (tree) => { + visit(tree, 'element', (node: Element) => { + if (node.tagName === 'td' || node.tagName === 'th') { + processCell(node); + } + }); +};