@@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties(
21862186 std::copy_n (node->nb , GGML_MAX_DIMS, prop.nb );
21872187
21882188 for (int src = 0 ; src < GGML_MAX_SRC; ++src) {
2189- prop.src_address [src] = node->src [src] ? node->src [src]->data : nullptr ;
2189+ if (node->src [src]) {
2190+ prop.src_address [src] = node->src [src]->data ;
2191+ std::copy_n (node->src [src]->ne , GGML_MAX_DIMS, prop.src_ne [src]);
2192+ std::copy_n (node->src [src]->nb , GGML_MAX_DIMS, prop.src_nb [src]);
2193+ } else {
2194+ prop.src_address [src] = nullptr ;
2195+ std::fill_n (prop.src_ne [src], GGML_MAX_DIMS, 0 );
2196+ std::fill_n (prop.src_nb [src], GGML_MAX_DIMS, 0 );
2197+ }
21902198 }
21912199
21922200 memcpy (prop.op_params , node->op_params , GGML_MAX_OP_PARAMS);
@@ -2206,14 +2214,18 @@ static void add_lru_matched_graph_node_properties(
22062214 * @param graph_node_properties The stored properties of a CANN graph node.
22072215 * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
22082216 */
2209- static bool ggml_graph_node_has_matching_properties (ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2217+ static bool ggml_graph_node_has_matching_properties (
2218+ ggml_tensor * node,
2219+ ggml_graph_node_properties * graph_node_properties) {
22102220 if (node->data != graph_node_properties->node_address &&
2211- node->op != GGML_OP_VIEW) {
2221+ node->op != GGML_OP_VIEW) {
22122222 return false ;
22132223 }
2224+
22142225 if (node->op != graph_node_properties->node_op ) {
22152226 return false ;
22162227 }
2228+
22172229 for (int i = 0 ; i < GGML_MAX_DIMS; i++) {
22182230 if (node->ne [i] != graph_node_properties->ne [i]) {
22192231 return false ;
@@ -2222,17 +2234,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
22222234 return false ;
22232235 }
22242236 }
2237+
22252238 for (int i = 0 ; i < GGML_MAX_SRC; i++) {
2226- if (node->src [i] &&
2227- node->src [i]->data != graph_node_properties->src_address [i] &&
2228- node->op != GGML_OP_VIEW
2229- ) {
2230- return false ;
2239+ if (node->src [i]) {
2240+ if (node->src [i]->data != graph_node_properties->src_address [i] &&
2241+ node->op != GGML_OP_VIEW) {
2242+ return false ;
2243+ }
2244+
2245+ for (int d = 0 ; d < GGML_MAX_DIMS; d++) {
2246+ if (node->src [i]->ne [d] != graph_node_properties->src_ne [i][d]) {
2247+ return false ;
2248+ }
2249+ if (node->src [i]->nb [d] != graph_node_properties->src_nb [i][d]) {
2250+ return false ;
2251+ }
2252+ }
2253+ } else {
2254+ if (graph_node_properties->src_address [i] != nullptr ) {
2255+ return false ;
2256+ }
22312257 }
22322258 }
2233- if (node-> op == GGML_OP_SCALE &&
2234- memcmp (graph_node_properties-> op_params , node->op_params , GGML_MAX_OP_PARAMS) != 0 ) {
2235- return false ;
2259+
2260+ if (node-> op == GGML_OP_SCALE || node-> op == GGML_OP_UNARY || node->op == GGML_OP_GLU ) {
2261+ return memcmp (graph_node_properties-> op_params , node-> op_params , GGML_MAX_OP_PARAMS) == 0 ;
22362262 }
22372263 return true ;
22382264}
0 commit comments