@@ -894,14 +894,13 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
 }
 
 /**
- * @brief Get or expand a cached float32 tensor filled with a scalar value.
+ * @brief Get or expand a cached tensor filled with a scalar value.
  *
- * This function manages cached device memory for float32 tensors. If the current
+ * This function manages cached device memory for tensors. If the current
  * cache size is insufficient for the requested tensor shape, the old memory will
- * be released and new memory will be allocated. The allocated buffer is then
- * initialized either with zeros (when @p value == 0.0f) or with the given scalar
- * value using CANN operations. Finally, an aclTensor object is created from the
- * cached memory and returned.
+ * be released and new memory will be allocated. The allocated buffer is
+ * initialized with the given scalar value using CANN operations.
+ * Finally, an aclTensor object is created from the cached memory and returned.
  *
  * @param ctx           The CANN backend context that manages device memory.
  * @param buffer        A pointer to the cached device buffer (will be allocated
@@ -910,25 +909,27 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
  *                      updated when the cache is expanded.
  * @param ne            The tensor shape array (number of elements in each dimension).
  * @param nb            The stride size for each dimension.
+ * @param dtype         Data type of the cached tensor.
  * @param dims          The number of tensor dimensions.
  * @param value         The scalar value used to fill the tensor (supports zero
  *                      initialization via memset or arbitrary values via fill_scalar).
  * @return              An aclTensor pointer created from the cached buffer.
  */
-static aclTensor* get_f32_cache_acl_tensor(
+static aclTensor* get_cache_acl_tensor(
     ggml_backend_cann_context& ctx,
     void** buffer,
     int64_t &cache_element,
     int64_t* ne,
     size_t* nb,
+    ggml_type dtype,
     int64_t dims,
     float value) {
     // Calculate total number of elements
     int64_t n_element = 1;
     for (int i = 0; i < dims; i++) {
         n_element *= ne[i];
     }
-    size_t size = n_element * sizeof(float);
+    size_t size = n_element * ggml_type_size(dtype);
 
     // Allocate or expand cache if needed
     if (cache_element < n_element) {
@@ -941,19 +942,17 @@ static aclTensor* get_f32_cache_acl_tensor(
         cache_element = n_element;
 
         // Initialize cache
-        if (value == 0.0f) {
-            ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
-        } else {
-            int64_t pool_ne[1] = { n_element };
-            size_t pool_nb[1] = { sizeof(float) };
-            aclTensor* acl_value = ggml_cann_create_tensor(
-                *buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
-            aclnn_fill_scalar(ctx, 1, acl_value);
-            ggml_cann_release_resources(ctx, acl_value);
-        }
+        int64_t pool_ne[1] = { n_element };
+        size_t pool_nb[1] = { ggml_type_size(dtype) };
+        aclTensor* acl_value = ggml_cann_create_tensor(
+            *buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype),
+            pool_ne, pool_nb, 1);
+        aclnn_fill_scalar(ctx, value, acl_value);
+        ggml_cann_release_resources(ctx, acl_value);
     }
 
-    return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
+    return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype),
+            ggml_type_size(dtype), ne, nb, dims);
 }
 
 void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
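Note: the grow-only cache pattern above, reduced to standard C++ as a minimal sketch (names here are illustrative, not part of the patch). The real helper additionally fills the new buffer via aclnn_fill_scalar and wraps the cached memory in an aclTensor of the requested dtype.

#include <cstdint>
#include <cstdlib>

// Hypothetical stand-in for the cached-buffer bookkeeping: the buffer only
// grows, and re-initialization is needed only after a (re)allocation.
struct scalar_cache {
    void*   buffer   = nullptr;
    int64_t elements = 0;
};

static void* cache_reserve(scalar_cache& c, int64_t n_element, size_t elem_size,
                           bool& needs_fill) {
    needs_fill = false;
    if (c.elements < n_element) {
        std::free(c.buffer);                           // release the old, smaller buffer
        c.buffer   = std::malloc(n_element * elem_size);
        c.elements = n_element;
        needs_fill = true;                             // caller must refill with its scalar
    }
    return c.buffer;
}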
@@ -965,35 +964,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    // build gamma, one...
+    // build gamma.
     size_t acl_gamma_nb[GGML_MAX_DIMS];
-    acl_gamma_nb[0] = sizeof(float);
+    // gamma's type is the same as dst's.
+    acl_gamma_nb[0] = ggml_type_size(dst->type);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
     }
-    aclTensor* acl_gamma = get_f32_cache_acl_tensor(
+    aclTensor* acl_gamma = get_cache_acl_tensor(
         ctx,
         &ctx.rms_norm_one_tensor_cache.cache,
         ctx.rms_norm_one_tensor_cache.size,
         src->ne,
         acl_gamma_nb,
+        dst->type,
         1,        // dims
         1.0f      // value
     );
 
-    // build rstd, zero...
+    // build rstd.
     int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]};
     size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
+    // rstd is always F32.
    acl_rstd_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
         acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
     }
-    aclTensor* acl_rstd = get_f32_cache_acl_tensor(
+    aclTensor* acl_rstd = get_cache_acl_tensor(
         ctx,
         &ctx.rms_norm_zero_tensor_cache.cache,
         ctx.rms_norm_zero_tensor_cache.size,
         acl_rstd_ne,
         acl_rstd_nb,
+        GGML_TYPE_F32,
         GGML_MAX_DIMS - 1,
         0.0f      // value
     );
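Note: with the new dtype parameter, the leading stride is the element size of the requested type rather than a hard-coded sizeof(float). A minimal sketch of the contiguous stride computation used for both gamma and rstd above (plain C++, illustrative only):

#include <cstddef>
#include <cstdint>

// Contiguous strides in bytes, in ggml's dimension order:
// nb[0] is the element size, nb[i] = nb[i-1] * ne[i-1].
static void contiguous_strides(const int64_t* ne, int dims, size_t elem_size, size_t* nb) {
    nb[0] = elem_size;                  // e.g. ggml_type_size(dst->type) for gamma
    for (int i = 1; i < dims; i++) {
        nb[i] = nb[i - 1] * (size_t) ne[i - 1];
    }
}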
@@ -1765,41 +1768,42 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];  // src
     ggml_tensor* src1 = dst->src[1];  // index
 
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
     switch (src0->type) {
-        case GGML_TYPE_F32: {
-            aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
-                                dst->data, dst->ne, dst->nb,
-                                src1, dst->type);
-            break;
-        }
-        case GGML_TYPE_F16: {
-            aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
-            ggml_cann_pool_alloc src_buffer_allocator(
-                ctx.pool(), ggml_nelements(src0) * sizeof(float));
-            void* src_trans_buffer = src_buffer_allocator.get();
-            size_t src_trans_nb[GGML_MAX_DIMS];
-            src_trans_nb[0] = sizeof(float);
-            for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
+        case GGML_TYPE_F16:
+        case GGML_TYPE_F32:
+            if (src0->type == dst->type) {
+                aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
+                                    dst->data, dst->ne, dst->nb,
+                                    src1, dst->type);
+            } else {
+                aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+                ggml_cann_pool_alloc src_buffer_allocator(
+                    ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
+                void* src_trans_buffer = src_buffer_allocator.get();
+                size_t src_trans_nb[GGML_MAX_DIMS];
+                src_trans_nb[0] = dst->nb[0];
+                for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                    src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
+                }
+                aclTensor* src_trans_tensor = ggml_cann_create_tensor(
+                    src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
+                    src0->ne, src_trans_nb, GGML_MAX_DIMS);
+                aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
+                aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
+                                    dst->data, dst->ne, dst->nb,
+                                    src1, dst->type);
+                ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
             }
-            aclTensor* src_trans_tensor = ggml_cann_create_tensor(
-                src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
-                src0->ne, src_trans_nb, GGML_MAX_DIMS);
-            aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
-            aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
-                                dst->data, dst->ne, dst->nb,
-                                src1, dst->type);
-            ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
             break;
-        }
         case GGML_TYPE_Q8_0: {
             // add 1 dim for bcast mul.
             size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1],
                 dequant_nb[GGML_MAX_DIMS + 1];
             int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1],
                 *dequant_ne;
             int64_t scale_offset = 0;
-
             // [3,4,5,64] -> [3,4,5,2,32]
             weight_ne[0] = QK8_0;
             weight_ne[1] = src0->ne[0] / QK8_0;
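Note: the unified F16/F32 branch gathers rows directly when source and destination types match, and otherwise casts the source into a staging buffer of the destination's element type first. A CPU reference sketch of the computed result (assuming flat rows and int32 indices; illustrative, not the CANN path):

#include <cstdint>

// Gather rows src[idx[r]] into dst, converting the element type in the inner
// loop (this is what cast + index_select achieve together on device).
template <typename TSrc, typename TDst>
static void get_rows_ref(const TSrc* src, int64_t row_len,
                         const int32_t* idx, int64_t n_rows, TDst* dst) {
    for (int64_t r = 0; r < n_rows; ++r) {
        const TSrc* row = src + (int64_t) idx[r] * row_len;
        for (int64_t c = 0; c < row_len; ++c) {
            dst[r * row_len + c] = (TDst) row[c];
        }
    }
}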
@@ -1809,7 +1813,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 weight_ne[i] = src0->ne[i - 1];
                 weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
             }
-
             // [3,4,5,64] -> [3,4,5,2,1]
             scale_ne[0] = 1;
             scale_ne[1] = src0->ne[0] / QK8_0;
@@ -1819,35 +1822,30 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 scale_ne[i] = src0->ne[i - 1];
                 scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
             }
-
             // [3,4,5,64] -> [3,4,5,2,32]
             dequant_ne = weight_ne;
-            dequant_nb[0] = sizeof(float);
+            dequant_nb[0] = ggml_type_size(dst->type);
             for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
                 dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
             }
-
             scale_offset = ggml_nelements(src0) * sizeof(int8_t);
             ggml_cann_pool_alloc dequant_buffer_allocator(
-                ctx.pool(), ggml_nelements(src0) * sizeof(float));
-
+                ctx.pool(), ggml_nelements(src0) * ggml_type_size(dst->type));
             aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
                 src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
                 GGML_MAX_DIMS + 1);
             aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
                 src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
                 GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
             aclTensor* dequant_tensor = ggml_cann_create_tensor(
-                dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
+                dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
                 dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
-
             aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
-            dequant_nb[0] = sizeof(float);
+            dequant_nb[0] = ggml_type_size(dst->type);
             dequant_ne = src0->ne;
             for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
             }
-
             aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(),
                                   dequant_ne, dequant_nb,
                                   dst->data, dst->ne, dst->nb,
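Note: the broadcast multiply above dequantizes Q8_0 on device: int8 weights viewed as [..., 2, 32] times fp16 scales viewed as [..., 2, 1], now producing the destination type instead of always f32. A scalar reference of the same arithmetic for one row (assumes QK8_0 == 32 and that codes and scales sit in two contiguous regions, as the scale_offset computation implies):

#include <cstdint>

// Dequantize n Q8_0 values: each block of 32 int8 codes shares one scale.
// fp16_to_fp32 stands in for the backend's half-to-float conversion.
static void dequant_row_q8_0(const int8_t* qs, const uint16_t* scales_f16,
                             float (*fp16_to_fp32)(uint16_t),
                             float* out, int64_t n) {
    const int64_t QK8_0 = 32;
    for (int64_t ib = 0; ib < n / QK8_0; ++ib) {
        const float d = fp16_to_fp32(scales_f16[ib]);
        for (int64_t j = 0; j < QK8_0; ++j) {
            out[ib * QK8_0 + j] = d * (float) qs[ib * QK8_0 + j];
        }
    }
}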
@@ -1965,16 +1963,8 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
     // Only check env once.
     static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
     if (weight_to_nz && is_matmul_weight(weight)) {
-        int64_t acl_stride[2] = {1, transpose_ne[1]};
-
-        // Reverse ne.
-        std::reverse(transpose_ne, transpose_ne + n_dims);
-
-        std::vector<int64_t> storageDims = {transpose_ne[0], transpose_ne[1]};
-
-        acl_weight_tensor = aclCreateTensor(
-            transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride,
-            0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data);
+        acl_weight_tensor =
+            ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
     } else {
         acl_weight_tensor =
             ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
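Note: the manual aclCreateTensor call is folded into ggml_cann_create_tensor, which accepts ACL_FORMAT_FRACTAL_NZ directly, so the two branches now differ only in the format argument. A sketch of the env-gated default in standard C++ (the real code uses ggml's get_env/parse_bool helpers; the accepted spellings below are an assumption):

#include <cstdlib>
#include <string>

// GGML_CANN_WEIGHT_NZ defaults to "on": matmul weights use Ascend's blocked
// FRACTAL_NZ layout unless explicitly disabled.
static bool weight_nz_enabled() {
    const char* v = std::getenv("GGML_CANN_WEIGHT_NZ");
    const std::string s = v ? v : "on";
    return s == "on" || s == "1" || s == "true" || s == "yes";
}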
@@ -3178,7 +3168,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         aclTensor* acl_src0_f16_tensor = nullptr;
         aclTensor* acl_src1_f16_tensor = nullptr;
         aclTensor* acl_src2_f16_tensor = nullptr;
-        aclTensor* acl_dst_f16_tensor  = nullptr;
 
         // Step 1: cast the src0 (Query) to fp16 if needed
         ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
@@ -3216,22 +3205,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
             src2_bsnd_nb, GGML_MAX_DIMS);
 
-        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
-        void* out_f16_buffer = out_f16_allocator.alloc(
-                                    ggml_nelements(dst) * faElemSize);
-
-        int64_t* out_f16_ne = src0_bsnd_ne;
-        size_t out_f16_nb[GGML_MAX_DIMS];
-        out_f16_nb[0] = faElemSize;
-        for (int i = 1; i < GGML_MAX_DIMS; ++i){
-            out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
-        }
-
-        acl_dst_f16_tensor = ggml_cann_create_tensor(
-            out_f16_buffer, faDataType, faElemSize,
-            out_f16_ne, out_f16_nb, GGML_MAX_DIMS
-        );
-
         // Step 3: create the PSEShift tensor if needed
         //         this tensor is considered as mask (f16) in the llama.cpp
         aclTensor* bcast_pse_tensor = nullptr;
@@ -3334,8 +3307,29 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         int64_t keyAntiquantMode = 0;
         int64_t valueAntiquantMode = 0;
 
-        // Step 5: launch the FusedInferAttentionScoreV2 kernel.
-        // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
+        GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+        aclTensor* fa_dst_tensor = nullptr;
+        aclTensor* acl_dst_tensor = nullptr;
+        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
+        if (dst->type == GGML_TYPE_F32) {
+            void* out_f16_buffer = out_f16_allocator.alloc(
+                                        ggml_nelements(dst) * faElemSize);
+
+            int64_t* out_f16_ne = src0_bsnd_ne;
+            size_t out_f16_nb[GGML_MAX_DIMS];
+            out_f16_nb[0] = faElemSize;
+            for (int i = 1; i < GGML_MAX_DIMS; ++i){
+                out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
+            }
+
+            fa_dst_tensor = ggml_cann_create_tensor(
+                out_f16_buffer, faDataType, faElemSize,
+                out_f16_ne, out_f16_nb, GGML_MAX_DIMS
+            );
+        }
+        else {
+            fa_dst_tensor = ggml_cann_create_tensor(dst);
+        }
 
         GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
             acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
@@ -3357,23 +3351,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             blockSize, antiquantMode, // blockSize, antiquantMode
             softmaxLseFlag, // softmaxLseFlag
             keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
-            acl_dst_f16_tensor, // attentionOut
+            fa_dst_tensor, // attentionOut
             nullptr // softmaxLse
         );
 
-        // Step 6: post-processing, permute and cast to f32
-        aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-        // TODO: when dst is fp16, don't need cast
-        aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
-        ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
-                                         acl_src1_f16_tensor,
-                                         acl_src2_f16_tensor,
-                                         acl_dst_f16_tensor,
-                                         acl_dst_tensor);
-        if (src3 != nullptr){
-            ggml_cann_release_resources(ctx, bcast_pse_tensor);
+        if (dst->type == GGML_TYPE_F32) {
+            // Step 6: post-processing, cast the fp16 output back to f32
+            // (assign the outer acl_dst_tensor rather than shadowing it, so the
+            // release below frees the tensor actually created here)
+            acl_dst_tensor = ggml_cann_create_tensor(dst);
+            aclnn_cast(ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
         }
-    }else {
+
+        ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
+                                        acl_src1_f16_tensor,
+                                        acl_src2_f16_tensor,
+                                        fa_dst_tensor,
+                                        acl_dst_tensor,
+                                        bcast_pse_tensor);
+
+    } else {
         GGML_ABORT("Function is not implemented.");
     }
 }
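Note: FusedInferAttentionScoreV2 emits fp16 output (faDataType), so an f32 destination still needs the staging buffer plus a cast, while an f16 destination is written directly and Step 6 is skipped. A minimal sketch of that dispatch, with hypothetical names:

// Decide whether the attention output needs an fp16 staging buffer and a
// post-kernel cast, mirroring the dst->type checks above.
enum class out_type { f16, f32 };

struct fa_output_plan {
    bool use_staging; // kernel writes a temporary fp16 buffer
    bool cast_to_dst; // cast the staging buffer into dst afterwards
};

static fa_output_plan plan_fa_output(out_type t) {
    if (t == out_type::f32) {
        return { true, true };   // fp16 kernel output -> cast -> f32 dst
    }
    return { false, false };     // fp16 dst: kernel writes dst directly
}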