From f9a9eb15e9248a47f7bde1210cb2ae479b3ba3b6 Mon Sep 17 00:00:00 2001 From: hipudding Date: Fri, 26 Sep 2025 03:51:08 +0000 Subject: [PATCH] Support FP16 as intermediate results in graph computation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit is a demo aimed at using FP16 as the data type for intermediate results in graph inference, reducing computation and improving inference speed. Verification was conducted with the CANN backend on Qwen2.5, Qwen3-MoE, and DeepSeek-Lite-V2, showing performance improvements of 3%–10% depending on the concurrency and model. The main changes include modifying operators involved in graph by replacing hardcoded FP32 data types with type inference based on input, adding FP16 support for GET_ROWS, and casting t_embd and t_logits back to FP32 at the end of inference. In fact, this is only a very basic validation. For full FP16 support, the following are still needed: 1. Modify all operators that currently hardcode FP32 to perform type inference based on the data type. 2. Add FP16 support to all backend operators. 3. Extend test cases to include FP16 data types. Co-authored-by: noemotiovon <757486878@qq.com> --- ggml/src/ggml-cpu/ops.cpp | 20 ++++++++++++++++---- ggml/src/ggml-cpu/vec.h | 1 + ggml/src/ggml.c | 21 ++++++++++++++++----- src/llama-graph.cpp | 32 ++++++++++++++++++++++++++++++++ src/llama-graph.h | 2 ++ src/llama-model.cpp | 3 +++ 6 files changed, 70 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 14f7dcf4f41ad..f7aa9282877d7 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4580,9 +4580,15 @@ static void ggml_compute_forward_get_rows_f16( GGML_ASSERT(i01 >= 0 && i01 < ne01); - ggml_cpu_fp16_to_fp32( - (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + // Supports both F16 and F32 as dst type. + if (dst->type == GGML_TYPE_F16) + ggml_vec_cpy_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), + (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + else + ggml_cpu_fp16_to_fp32( + (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); } } @@ -4662,9 +4668,15 @@ static void ggml_compute_forward_get_rows_f32( GGML_ASSERT(i01 >= 0 && i01 < ne01); - ggml_vec_cpy_f32(nc, + // Supports both F16 and F32 as dst type. + if (dst->type == GGML_TYPE_F32) + ggml_vec_cpy_f32(nc, (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + else + ggml_cpu_fp32_to_fp16( + (const float*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (ggml_fp16_t *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); } } diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index ef334d089d1f7..db2266dfafc61 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -87,6 +87,7 @@ inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp } inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_cpy_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index fe36bab8362b2..82f81b31ab02a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3024,7 +3024,10 @@ struct ggml_tensor * ggml_mul_mat( GGML_ASSERT(!ggml_is_transposed(a)); const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + // Tensor a is the weight, with its type determined by the model file. + // Tensor b is the activation, i.e., the intermediate computation result. + // Here, the destination type (dst) is kept the same as the input activation type. + struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne); result->op = GGML_OP_MUL_MAT; result->src[0] = a; @@ -3073,7 +3076,9 @@ struct ggml_tensor * ggml_mul_mat_id( GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + // Tensor b is the activation, i.e., the intermediate computation result. + // Here, the destination type (dst) is kept the same as the input activation type. + struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne); result->op = GGML_OP_MUL_MAT_ID; result->src[0] = as; @@ -3628,7 +3633,9 @@ struct ggml_tensor * ggml_get_rows( GGML_ASSERT(b->type == GGML_TYPE_I32); // TODO: implement non F32 return - enum ggml_type type = GGML_TYPE_F32; + // TODO: Automatically select the destination type based on parameters, + // environment variables, or backend support. Hard code F16 for example. + enum ggml_type type = GGML_TYPE_F16; if (a->type == GGML_TYPE_I32) { type = a->type; } @@ -3676,7 +3683,8 @@ struct ggml_tensor * ggml_set_rows( GGML_ASSERT(b->ne[2] % c->ne[1] == 0); GGML_ASSERT(b->ne[3] % c->ne[2] == 0); GGML_ASSERT(c->ne[3] == 1); - GGML_ASSERT(b->type == GGML_TYPE_F32); + // b->type also can be F16. + //GGML_ASSERT(b->type == GGML_TYPE_F32); GGML_ASSERT(c->type == GGML_TYPE_I64 || c->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_contiguous_rows(a)); @@ -5003,7 +5011,10 @@ struct ggml_tensor * ggml_flash_attn_ext( // permute(0, 2, 1, 3) int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + // The types of k and v are the same as those in the KV cache, + // while q is an intermediate computation result. + // Here, the destination type (dst) is kept the same as the type of q. + struct ggml_tensor * result = ggml_new_tensor(ctx, q->type, 4, ne); float params[] = { scale, max_bias, logit_softcap }; ggml_set_op_params(result, params, sizeof(params)); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 90cd885a60a4f..d1a8a991af611 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1937,6 +1937,38 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } +void llm_graph_context::cast_outputs() const { + ggml_tensor * ori_embd = res->t_embd; + if (cparams.embeddings && res->t_embd->type != GGML_TYPE_F32) { + ggml_tensor * embd = res->t_embd; + embd = ggml_cast(ctx0, embd, GGML_TYPE_F32); + cb(embd, "result_embd_cast", -1); + ggml_build_forward_expand(gf, embd); + res->t_embd = embd; + } + + if (cparams.embeddings && res->t_embd_pooled->type != GGML_TYPE_F32) { + // if LLAMA_POOLING_TYPE_NONE, embd_pooled == embd + if (res->t_embd_pooled == ori_embd) { + res->t_embd_pooled = res->t_embd; + } else { + ggml_tensor * embd_pooled = res->t_embd_pooled; + embd_pooled = ggml_cast(ctx0, embd_pooled, GGML_TYPE_F32); + cb(embd_pooled, "result_embd_pooled_cast", -1); + ggml_build_forward_expand(gf, embd_pooled); + res->t_embd_pooled = embd_pooled; + } + } + + if(res->t_logits->type != GGML_TYPE_F32) { + ggml_tensor * logits = res->t_logits; + logits = ggml_cast(ctx0, logits, GGML_TYPE_F32); + cb(logits, "result_logits_cast", -1); + ggml_build_forward_expand(gf, logits); + res->t_logits = logits; + } +} + int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; diff --git a/src/llama-graph.h b/src/llama-graph.h index 34b984afeb043..cb97183feb74c 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -814,6 +814,8 @@ struct llm_graph_context { ggml_tensor * cls_b, ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; + + void cast_outputs() const; }; // TODO: better name diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2ae9abb4464fd..bf8ac3278cccb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -19618,6 +19618,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + // cast output to F32 + llm->cast_outputs(); + return llm->res->get_gf(); }