Skip to content

Commit 7f7d576

Browse files
committed
测试用的修改,对照区别用。(Modification for testing; used to compare differences.)
1 parent 5fa8bb1 commit 7f7d576

File tree

1 file changed

+34
-13
lines changed

1 file changed

+34
-13
lines changed

sense-voice/csrc/sense-voice-encoder.cc

Lines changed: 34 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -187,23 +187,41 @@ static struct ggml_tensor *encoder_layer_sanm_forward(const sense_voice_hparams
187187
// fsmn_memory = ggml_concat(ctx0, fsmn_memory, fsmn_memory_batch, 2);
188188
// }
189189
// }
190+
bool flag = false;
190191
struct ggml_tensor * im2col = ggml_im2col(ctx0, a, ggml_reshape_4d(ctx0, b, b->ne[0], 1, b->ne[1] * b->ne[2], b->ne[3]), 1, 0, padding, 0, 1, 0, false, GGML_TYPE_F32);
191-
im2col = ggml_reshape_4d(ctx0, im2col, im2col->ne[0], im2col->ne[1], im2col->ne[2] / n_batch, n_batch);
192-
a = ggml_repeat(ctx0, ggml_cast(ctx0, a, GGML_TYPE_F32), ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, a->ne[0], a->ne[1], a->ne[2], n_batch));
193-
struct ggml_tensor * result = ggml_mul_mat(ctx0, a, im2col);
194-
fsmn_memory = ggml_reshape_3d(ctx0, result, im2col->ne[1], im2col->ne[2], im2col->ne[3]);
195-
// if(n_batch > 1){
196-
// printf("n_batch: %d\n", n_batch);
197-
// printf("a: %ld %ld %ld %ld\n", a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
198-
// printf("b: %ld %ld %ld %ld\n", b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
199-
// printf("im2col: %ld %ld %ld %ld\n", im2col->ne[0], im2col->ne[1], im2col->ne[2], im2col->ne[3]);
192+
if(flag){
193+
im2col = ggml_reshape_4d(ctx0, im2col, im2col->ne[0], im2col->ne[1], im2col->ne[2] / n_batch, n_batch); //1
194+
a = ggml_repeat(ctx0, ggml_cast(ctx0, a, GGML_TYPE_F32), ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, a->ne[0], a->ne[1], a->ne[2], n_batch));//2
195+
// printf("a dims: [%d, %d, %d, %d], \n", a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
196+
// printf("im2col dims: [%d, %d, %d, %d], \n", im2col->ne[0], im2col->ne[1], im2col->ne[2], im2col->ne[3]);
197+
struct ggml_tensor * result = ggml_mul_mat(ctx0, a, im2col);
198+
// printf("result dims: [%d, %d, %d, %d], \n", result->ne[0], result->ne[1], result->ne[2], result->ne[3]);
199+
fsmn_memory = ggml_reshape_3d(ctx0, result, im2col->ne[1], im2col->ne[2], im2col->ne[3]); //3
200+
// printf("true - fsmn_memory dims: [%d, %d, %d, %d]\n",
201+
// fsmn_memory->ne[0], fsmn_memory->ne[1],
202+
// fsmn_memory->ne[2], fsmn_memory->ne[3]);
203+
} else {
204+
a = ggml_repeat(ctx0, ggml_cast(ctx0, a, GGML_TYPE_F32), ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, a->ne[0], a->ne[1], im2col->ne[2]));//2
205+
struct ggml_tensor * result = ggml_mul_mat(ctx0, a, im2col);
206+
// a => [11, 1, 512, 1], b => [11, X, 2048 = 512 * batch, 1], result => [1, X, 2048, 1]
207+
// a => [11, 1, 512, 1] => [11, 1, 512, batch], im2col => [11, X, 512, batch], result => [1, X, 512, batch]
208+
// struct ggml_tensor * result = ggml_mul_mat(ctx0, ggml_cast(ctx0, a, GGML_TYPE_F32), im2col);
209+
fsmn_memory = ggml_reshape_3d(ctx0, result, im2col->ne[1], im2col->ne[2] / n_batch, im2col->ne[3] * n_batch); //4
210+
}
211+
if(n_batch > 1){
212+
printf("n_batch: %d\n", n_batch);
213+
printf("a: %ld %ld %ld %ld\n", a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
214+
printf("b: %ld %ld %ld %ld\n", b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
215+
printf("im2col: %ld %ld %ld %ld\n", im2col->ne[0], im2col->ne[1], im2col->ne[2], im2col->ne[3]);
200216
// printf("result: %ld %ld %ld %ld\n", result->ne[0], result->ne[1], result->ne[2], result->ne[3]);
201-
// printf("fsmn_memory: %ld %ld %ld %ld\n", fsmn_memory->ne[0], fsmn_memory->ne[1], fsmn_memory->ne[2], fsmn_memory->ne[3]);
202-
// printf("V: %ld %ld %ld %ld\n", V->ne[0], V->ne[1], V->ne[2], V->ne[3]);
203-
// }
217+
printf("fsmn_memory: %ld %ld %ld %ld\n", fsmn_memory->ne[0], fsmn_memory->ne[1], fsmn_memory->ne[2], fsmn_memory->ne[3]);
218+
printf("V: %ld %ld %ld %ld\n", V->ne[0], V->ne[1], V->ne[2], V->ne[3]);
219+
}
204220
}
221+
205222
fsmn_memory = ggml_cont(ctx0, ggml_transpose(ctx0, fsmn_memory));
206-
fsmn_memory = ggml_add(ctx0, fsmn_memory, V);
223+
//fsmn_memory = ggml_cont(ctx0, fsmn_memory);
224+
207225
ggml_set_name(fsmn_memory, "fsmn_memory");
208226
}
209227

@@ -337,10 +355,12 @@ struct ggml_cgraph *sense_voice_build_graph_encoder(sense_voice_context &pctx,
337355
cur = encoder_layer_sanm_forward(hparams, pctx, ctx0, cur, model->encoder->encoder0, gf, pctx.params.flash_attn);
338356

339357
// encoders forward
358+
printf("begin layer==========================================> \n");
340359
for (int i=0; i < hparams.n_encoder_layers - 1; i++){
341360
cur = encoder_layer_sanm_forward(hparams, pctx, ctx0, cur, model->encoder->encoders_layer[i], gf, pctx.params.flash_attn);
342361
}
343362

363+
printf("end layer==========================================> \n");
344364
{
345365
// after encoder norm
346366
cur = ggml_norm(ctx0, cur, hparams.eps);
@@ -354,6 +374,7 @@ struct ggml_cgraph *sense_voice_build_graph_encoder(sense_voice_context &pctx,
354374
cur = encoder_layer_sanm_forward(hparams, pctx, ctx0, cur, model->encoder->tp_encoders_layer[i], gf, pctx.params.flash_attn);
355375
}
356376

377+
printf("end n_tp_encoder_layers==========================================> \n");
357378
{
358379
// tp encoder norm
359380
cur = ggml_norm(ctx0, cur, hparams.eps);

0 commit comments

Comments (0)