@@ -187,23 +187,41 @@ static struct ggml_tensor *encoder_layer_sanm_forward(const sense_voice_hparams
187187 // fsmn_memory = ggml_concat(ctx0, fsmn_memory, fsmn_memory_batch, 2);
188188 // }
189189 // }
190+ bool flag = false ;
190191 struct ggml_tensor * im2col = ggml_im2col (ctx0, a, ggml_reshape_4d (ctx0, b, b->ne [0 ], 1 , b->ne [1 ] * b->ne [2 ], b->ne [3 ]), 1 , 0 , padding, 0 , 1 , 0 , false , GGML_TYPE_F32);
191- im2col = ggml_reshape_4d (ctx0, im2col, im2col->ne [0 ], im2col->ne [1 ], im2col->ne [2 ] / n_batch, n_batch);
192- a = ggml_repeat (ctx0, ggml_cast (ctx0, a, GGML_TYPE_F32), ggml_new_tensor_4d (ctx0, GGML_TYPE_F16, a->ne [0 ], a->ne [1 ], a->ne [2 ], n_batch));
193- struct ggml_tensor * result = ggml_mul_mat (ctx0, a, im2col);
194- fsmn_memory = ggml_reshape_3d (ctx0, result, im2col->ne [1 ], im2col->ne [2 ], im2col->ne [3 ]);
195- // if(n_batch > 1){
196- // printf("n_batch: %d\n", n_batch);
197- // printf("a: %ld %ld %ld %ld\n", a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
198- // printf("b: %ld %ld %ld %ld\n", b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
199- // printf("im2col: %ld %ld %ld %ld\n", im2col->ne[0], im2col->ne[1], im2col->ne[2], im2col->ne[3]);
192+ if (flag){
193+ im2col = ggml_reshape_4d (ctx0, im2col, im2col->ne [0 ], im2col->ne [1 ], im2col->ne [2 ] / n_batch, n_batch); // 1
194+ a = ggml_repeat (ctx0, ggml_cast (ctx0, a, GGML_TYPE_F32), ggml_new_tensor_4d (ctx0, GGML_TYPE_F16, a->ne [0 ], a->ne [1 ], a->ne [2 ], n_batch));// 2
195+ // printf("a dims: [%d, %d, %d, %d], \n", a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
196+ // printf("im2col dims: [%d, %d, %d, %d], \n", im2col->ne[0], im2col->ne[1], im2col->ne[2], im2col->ne[3]);
197+ struct ggml_tensor * result = ggml_mul_mat (ctx0, a, im2col);
198+ // printf("result dims: [%d, %d, %d, %d], \n", result->ne[0], result->ne[1], result->ne[2], result->ne[3]);
199+ fsmn_memory = ggml_reshape_3d (ctx0, result, im2col->ne [1 ], im2col->ne [2 ], im2col->ne [3 ]); // 3
200+ // printf("true - fsmn_memory dims: [%d, %d, %d, %d]\n",
201+ // fsmn_memory->ne[0], fsmn_memory->ne[1],
202+ // fsmn_memory->ne[2], fsmn_memory->ne[3]);
203+ } else {
204+ a = ggml_repeat (ctx0, ggml_cast (ctx0, a, GGML_TYPE_F32), ggml_new_tensor_3d (ctx0, GGML_TYPE_F16, a->ne [0 ], a->ne [1 ], im2col->ne [2 ]));// 2
205+ struct ggml_tensor * result = ggml_mul_mat (ctx0, a, im2col);
206+ // a => [11, 1, 512, 1], b => [11, X, 2048 = 512 * batch, 1], result => [1, X, 2048, 1]
207+ // a => [11, 1, 512, 1] => [11, 1, 512, batch], im2col => [11, X, 512, batch], result => [1, X, 512, batch]
208+ // struct ggml_tensor * result = ggml_mul_mat(ctx0, ggml_cast(ctx0, a, GGML_TYPE_F32), im2col);
209+ fsmn_memory = ggml_reshape_3d (ctx0, result, im2col->ne [1 ], im2col->ne [2 ] / n_batch, im2col->ne [3 ] * n_batch); // 4
210+ }
211+ if (n_batch > 1 ){
212+ printf (" n_batch: %d\n " , n_batch);
213+ printf (" a: %ld %ld %ld %ld\n " , a->ne [0 ], a->ne [1 ], a->ne [2 ], a->ne [3 ]);
214+ printf (" b: %ld %ld %ld %ld\n " , b->ne [0 ], b->ne [1 ], b->ne [2 ], b->ne [3 ]);
215+ printf (" im2col: %ld %ld %ld %ld\n " , im2col->ne [0 ], im2col->ne [1 ], im2col->ne [2 ], im2col->ne [3 ]);
200216 // printf("result: %ld %ld %ld %ld\n", result->ne[0], result->ne[1], result->ne[2], result->ne[3]);
201- // printf("fsmn_memory: %ld %ld %ld %ld\n", fsmn_memory->ne[0], fsmn_memory->ne[1], fsmn_memory->ne[2], fsmn_memory->ne[3]);
202- // printf("V: %ld %ld %ld %ld\n", V->ne[0], V->ne[1], V->ne[2], V->ne[3]);
203- // }
217+ printf (" fsmn_memory: %ld %ld %ld %ld\n " , fsmn_memory->ne [0 ], fsmn_memory->ne [1 ], fsmn_memory->ne [2 ], fsmn_memory->ne [3 ]);
218+ printf (" V: %ld %ld %ld %ld\n " , V->ne [0 ], V->ne [1 ], V->ne [2 ], V->ne [3 ]);
219+ }
204220 }
221+
205222 fsmn_memory = ggml_cont (ctx0, ggml_transpose (ctx0, fsmn_memory));
206- fsmn_memory = ggml_add (ctx0, fsmn_memory, V);
223+ // fsmn_memory = ggml_cont(ctx0, fsmn_memory);
224+
207225 ggml_set_name (fsmn_memory, " fsmn_memory" );
208226 }
209227
@@ -337,10 +355,12 @@ struct ggml_cgraph *sense_voice_build_graph_encoder(sense_voice_context &pctx,
337355 cur = encoder_layer_sanm_forward (hparams, pctx, ctx0, cur, model->encoder ->encoder0 , gf, pctx.params .flash_attn );
338356
339357 // encoders forward
358+ printf (" begin layer==========================================> \n " );
340359 for (int i=0 ; i < hparams.n_encoder_layers - 1 ; i++){
341360 cur = encoder_layer_sanm_forward (hparams, pctx, ctx0, cur, model->encoder ->encoders_layer [i], gf, pctx.params .flash_attn );
342361 }
343362
363+ printf (" end layer==========================================> \n " );
344364 {
345365 // after encoder norm
346366 cur = ggml_norm (ctx0, cur, hparams.eps );
@@ -354,6 +374,7 @@ struct ggml_cgraph *sense_voice_build_graph_encoder(sense_voice_context &pctx,
354374 cur = encoder_layer_sanm_forward (hparams, pctx, ctx0, cur, model->encoder ->tp_encoders_layer [i], gf, pctx.params .flash_attn );
355375 }
356376
377+ printf (" end n_tp_encoder_layers==========================================> \n " );
357378 {
358379 // tp encoder norm
359380 cur = ggml_norm (ctx0, cur, hparams.eps );
0 commit comments