diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c
index 2282534b0f..5b5a9db8d9 100644
--- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c
+++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c
@@ -34,6 +34,8 @@
 #define OPENVINO_BACKEND_LIB "libwasi_nn_openvino" LIB_EXTENTION
 #define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION
 #define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION
+#define TFLITE_MODEL_FILE_EXT ".tflite"
+#define ONNX_MODEL_FILE_EXT ".onnx"
 
 /* Global variables */
 static korp_mutex wasi_nn_lock;
@@ -56,6 +58,12 @@ struct backends_api_functions {
         NN_ERR_PRINTF("Error %s() -> %d", #func, wasi_error);     \
     } while (0)
 
+static graph_encoding auto_detect_encoding_order[] = {
+    tensorflowlite, onnx, openvino, tensorflow,
+    pytorch, ggml, autodetect, unknown_backend
+};
+static int auto_detect_encoding_num =
+    sizeof(auto_detect_encoding_order) / sizeof(graph_encoding);
 static void *wasi_nn_key;
 
 static void
@@ -491,20 +499,51 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
         goto fail;
     }
 
-    res = ensure_backend(instance, encoding, wasi_nn_ctx);
-    if (res != success)
-        goto fail;
+    if (encoding == autodetect) {
+        for (int i = 0; i < auto_detect_encoding_num; i++) {
+            if (wasi_nn_ctx->is_backend_ctx_initialized) {
+                call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
+                                  wasi_nn_ctx->backend_ctx);
+            }
+
+            res = ensure_backend(instance, auto_detect_encoding_order[i],
+                                 wasi_nn_ctx);
+            if (res != success) {
+                NN_ERR_PRINTF("continue trying the next");
+                continue;
+            }
+
+            call_wasi_nn_func(wasi_nn_ctx->backend, load, res,
+                              wasi_nn_ctx->backend_ctx, &builder_native,
+                              auto_detect_encoding_order[i], target, g);
+            if (res != success) {
+                NN_ERR_PRINTF("continue trying the next");
+                continue;
+            }
+
+            break;
+        }
+    }
+    else {
+        res = ensure_backend(instance, encoding, wasi_nn_ctx);
+        if (res != success)
+            goto fail;
 
-    call_wasi_nn_func(wasi_nn_ctx->backend, load, res, wasi_nn_ctx->backend_ctx,
-                      &builder_native, encoding, target, g);
-    if (res != success)
-        goto fail;
+        call_wasi_nn_func(wasi_nn_ctx->backend, load, res,
+                          wasi_nn_ctx->backend_ctx, &builder_native, encoding,
+                          target, g);
+        if (res != success)
+            goto fail;
+    }
 
 fail:
     // XXX: Free intermediate structure pointers
-    if (builder_native.buf)
+    if (builder_native.buf) {
         wasm_runtime_free(builder_native.buf);
-    unlock_ctx(wasi_nn_ctx);
+    }
+    if (wasi_nn_ctx) {
+        unlock_ctx(wasi_nn_ctx);
+    }
 
     return res;
 }
@@ -532,12 +571,25 @@ copyin_and_nul_terminate(wasm_module_inst_t inst, char *name, uint32_t name_len,
     return success;
 }
 
+static bool
+ends_with(const char *str, const char *suffix)
+{
+    if (!str || !suffix)
+        return false;
+    uint32_t lenstr = strlen(str);
+    uint32_t lensuf = strlen(suffix);
+    if (lensuf > lenstr)
+        return false;
+    return strcmp(str + lenstr - lensuf, suffix) == 0;
+}
+
 wasi_nn_error
 wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
                      graph *g)
 {
     WASINNContext *wasi_nn_ctx = NULL;
     char *nul_terminated_name = NULL;
+    graph_encoding encoding = unknown_backend;
     wasi_nn_error res;
     wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
 
@@ -565,17 +617,56 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
         goto fail;
     }
 
-    res = ensure_backend(instance, autodetect, wasi_nn_ctx);
-    if (res != success)
-        goto fail;
+    if (ends_with(nul_terminated_name, TFLITE_MODEL_FILE_EXT)) {
+        encoding = tensorflowlite;
+    }
+    else if (ends_with(nul_terminated_name, ONNX_MODEL_FILE_EXT)) {
+        encoding = onnx;
+    }
+    if (encoding == unknown_backend) {
+        for (int i = 0; i < auto_detect_encoding_num; i++) {
+            if (wasi_nn_ctx->is_backend_ctx_initialized) {
+                call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
+                                  wasi_nn_ctx->backend_ctx);
+            }
 
-    call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
-                      wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len,
-                      g);
-    if (res != success)
-        goto fail;
+            res = ensure_backend(instance, auto_detect_encoding_order[i],
+                                 wasi_nn_ctx);
+            if (res != success) {
+                NN_ERR_PRINTF("continue trying the next");
+                continue;
+            }
+
+            call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
+                              wasi_nn_ctx->backend_ctx, nul_terminated_name,
+                              name_len, g);
+            if (res != success) {
+                NN_ERR_PRINTF("continue trying the next");
+                continue;
+            }
+
+            break;
+        }
+    }
+    else {
+        if (wasi_nn_ctx->is_backend_ctx_initialized) {
+            call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
+                              wasi_nn_ctx->backend_ctx);
+        }
+
+        res = ensure_backend(instance, encoding, wasi_nn_ctx);
+        if (res != success) {
+            goto fail;
+        }
+
+        call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
+                          wasi_nn_ctx->backend_ctx, nul_terminated_name,
+                          name_len, g);
+        if (res != success) {
+            goto fail;
+        }
+    }
 
-    res = success;
 fail:
     if (nul_terminated_name != NULL) {
         wasm_runtime_free(nul_terminated_name);
@@ -594,6 +685,7 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
     WASINNContext *wasi_nn_ctx = NULL;
     char *nul_terminated_name = NULL;
     char *nul_terminated_config = NULL;
+    graph_encoding encoding = unknown_backend;
     wasi_nn_error res;
     wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
 
@@ -627,18 +719,57 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
         goto fail;
     }
 
-    res = ensure_backend(instance, autodetect, wasi_nn_ctx);
-    if (res != success)
-        goto fail;
-    ;
+    if (ends_with(nul_terminated_name, TFLITE_MODEL_FILE_EXT)) {
+        encoding = tensorflowlite;
+    }
+    else if (ends_with(nul_terminated_name, ONNX_MODEL_FILE_EXT)) {
+        encoding = onnx;
+    }
+    if (encoding == unknown_backend) {
+        for (int i = 0; i < auto_detect_encoding_num; i++) {
+            if (wasi_nn_ctx->is_backend_ctx_initialized) {
+                call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
+                                  wasi_nn_ctx->backend_ctx);
+            }
+
+            res = ensure_backend(instance, auto_detect_encoding_order[i],
+                                 wasi_nn_ctx);
+            if (res != success) {
+                NN_ERR_PRINTF("continue trying the next");
+                continue;
+            }
+
+            call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config,
+                              res, wasi_nn_ctx->backend_ctx,
+                              nul_terminated_name, name_len,
+                              nul_terminated_config, config_len, g);
+            if (res != success) {
+                NN_ERR_PRINTF("continue trying the next");
+                continue;
+            }
+
+            break;
+        }
+    }
+    else {
+        if (wasi_nn_ctx->is_backend_ctx_initialized) {
+            call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
+                              wasi_nn_ctx->backend_ctx);
+        }
 
-    call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res,
-                      wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len,
-                      nul_terminated_config, config_len, g);
-    if (res != success)
-        goto fail;
+        res = ensure_backend(instance, encoding, wasi_nn_ctx);
+        if (res != success) {
+            goto fail;
+        }
+
+        call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res,
+                          wasi_nn_ctx->backend_ctx, nul_terminated_name,
+                          name_len, nul_terminated_config, config_len, g);
+        if (res != success) {
+            goto fail;
+        }
+    }
 
-    res = success;
 fail:
     if (nul_terminated_name != NULL) {
         wasm_runtime_free(nul_terminated_name);
diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp
index 9ac54e6644..0d9d55a1ed 100644
--- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp
+++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp
@@ -136,6 +136,17 @@ load(void *tflite_ctx, graph_builder_array *builder, graph_encoding encoding,
 
     uint32_t size = builder->buf[0].size;
 
+    /* Validate the caller-supplied buffer BEFORE it is copied below:
+     * models[*g].model_pointer is not assigned yet at this point. */
+    if (size < 8) {
+        NN_ERR_PRINTF("Model too small to be a valid TFLite file.");
+        return invalid_argument;
+    }
+    if (memcmp(builder->buf[0].buf + 4, "TFL3", 4) != 0) {
+        NN_ERR_PRINTF(
+            "Model file is not a TFLite FlatBuffer (missing TFL3 identifier).");
+        return invalid_argument;
+    }
+
     // Save model
     tfl_ctx->models[*g].model_pointer = (char *)wasm_runtime_malloc(size);
     if (tfl_ctx->models[*g].model_pointer == NULL) {