diff --git a/src/kokoro_model.cpp b/src/kokoro_model.cpp index f73dddb..acf1168 100644 --- a/src/kokoro_model.cpp +++ b/src/kokoro_model.cpp @@ -1,5 +1,7 @@ #include "kokoro_model.h" +#define ggml_cast_if_needed(ctx, x, qtype) ((x)->type == (qtype) ? (x) : ggml_cast((ctx), (x), (qtype))) + static struct ggml_tensor * build_albert_attn_mask(ggml_context * ctx, struct kokoro_duration_context *kctx, const kokoro_ubatch & batch) { kctx->attn_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, (int64_t) batch.n_tokens, (int64_t) batch.n_tokens); ggml_set_input(kctx->attn_mask); @@ -943,7 +945,7 @@ struct ggml_cgraph * kokoro_duration_runner::build_kokoro_duration_graph(kokoro_ // In order to side step this problem I computed the graph and determined the size in advance and use that constant value here. struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, 110000, false); - struct ggml_tensor * voice = model->voices[kctx->voice]; + struct ggml_tensor * voice = ggml_cast_if_needed(ctx, model->voices[kctx->voice], GGML_TYPE_F32); struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -1146,7 +1148,7 @@ struct ggml_cgraph * kokoro_runner::build_kokoro_graph(kokoro_ubatch & batch) { // In order to side step this problem I computed the graph and determined the size in advance and use that constant value here. struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, 570000, false); - struct ggml_tensor * voice = model->voices[kctx->voice]; + struct ggml_tensor * voice = ggml_cast_if_needed(ctx, model->voices[kctx->voice], GGML_TYPE_F32); struct ggml_tensor * style_half = ggml_view_1d(ctx, voice, voice->ne[0]/2, voice->ne[0] / 2 * voice->nb[0] + (batch.n_tokens - 3) * voice->nb[1]); struct ggml_tensor * cur; diff --git a/src/tts.cpp b/src/tts.cpp index f5faf28..6a0b07f 100644 --- a/src/tts.cpp +++ b/src/tts.cpp @@ -191,8 +191,7 @@ void update_conditional_prompt(tts_runner * runner, const std::string file_path, } bool kokoro_is_f16_compatible(std::string name) { - return name.find("voice_tensors") == std::string::npos && - name.find("bias") == std::string::npos && + return name.find("bias") == std::string::npos && name.find("gamma") == std::string::npos && name.find("beta") == std::string::npos && name.find("alpha") == std::string::npos &&