Skip to content

Commit d3ab94e

Browse files
committed
Support graph encoding backend auto-detection in wasi-nn.
1 parent 4b42cfd commit d3ab94e

File tree

1 file changed

+81
-28
lines changed

1 file changed

+81
-28
lines changed

core/iwasm/libraries/wasi-nn/src/wasi_nn.c

Lines changed: 81 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -491,20 +491,50 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
491491
goto fail;
492492
}
493493

494-
res = ensure_backend(instance, encoding, wasi_nn_ctx);
495-
if (res != success)
496-
goto fail;
494+
if (encoding == autodetect) {
495+
for (graph_encoding e = openvino; e <= unknown_backend; e++) {
496+
if (wasi_nn_ctx->is_backend_ctx_initialized) {
497+
call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
498+
wasi_nn_ctx->backend_ctx);
499+
}
500+
501+
res = ensure_backend(instance, e, wasi_nn_ctx);
502+
if (res != success) {
503+
NN_ERR_PRINTF("continue trying the next");
504+
continue;
505+
}
506+
507+
call_wasi_nn_func(wasi_nn_ctx->backend, load, res,
508+
wasi_nn_ctx->backend_ctx, &builder_native, e,
509+
target, g);
510+
if (res != success) {
511+
NN_ERR_PRINTF("continue trying the next");
512+
continue;
513+
}
514+
515+
break;
516+
}
517+
}
518+
else {
519+
res = ensure_backend(instance, encoding, wasi_nn_ctx);
520+
if (res != success)
521+
goto fail;
497522

498-
call_wasi_nn_func(wasi_nn_ctx->backend, load, res, wasi_nn_ctx->backend_ctx,
499-
&builder_native, encoding, target, g);
500-
if (res != success)
501-
goto fail;
523+
call_wasi_nn_func(wasi_nn_ctx->backend, load, res,
524+
wasi_nn_ctx->backend_ctx, &builder_native, encoding,
525+
target, g);
526+
if (res != success)
527+
goto fail;
528+
}
502529

503530
fail:
504531
// XXX: Free intermediate structure pointers
505-
if (builder_native.buf)
532+
if (builder_native.buf) {
506533
wasm_runtime_free(builder_native.buf);
507-
unlock_ctx(wasi_nn_ctx);
534+
}
535+
if (wasi_nn_ctx) {
536+
unlock_ctx(wasi_nn_ctx);
537+
}
508538

509539
return res;
510540
}
@@ -565,17 +595,29 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
565595
goto fail;
566596
}
567597

568-
res = ensure_backend(instance, autodetect, wasi_nn_ctx);
569-
if (res != success)
570-
goto fail;
598+
for (graph_encoding e = openvino; e <= unknown_backend; e++) {
599+
if (wasi_nn_ctx->is_backend_ctx_initialized) {
600+
call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
601+
wasi_nn_ctx->backend_ctx);
602+
}
571603

572-
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
573-
wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len,
574-
g);
575-
if (res != success)
576-
goto fail;
604+
res = ensure_backend(instance, e, wasi_nn_ctx);
605+
if (res != success) {
606+
NN_ERR_PRINTF("continue trying the next");
607+
continue;
608+
}
609+
610+
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
611+
wasi_nn_ctx->backend_ctx, nul_terminated_name,
612+
name_len, g);
613+
if (res != success) {
614+
NN_ERR_PRINTF("continue trying the next");
615+
continue;
616+
}
617+
618+
break;
619+
}
577620

578-
res = success;
579621
fail:
580622
if (nul_terminated_name != NULL) {
581623
wasm_runtime_free(nul_terminated_name);
@@ -627,18 +669,29 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
627669
goto fail;
628670
}
629671

630-
res = ensure_backend(instance, autodetect, wasi_nn_ctx);
631-
if (res != success)
632-
goto fail;
633-
;
672+
for (graph_encoding e = openvino; e <= unknown_backend; e++) {
673+
if (wasi_nn_ctx->is_backend_ctx_initialized) {
674+
call_wasi_nn_func(wasi_nn_ctx->backend, deinit, res,
675+
wasi_nn_ctx->backend_ctx);
676+
}
634677

635-
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res,
636-
wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len,
637-
nul_terminated_config, config_len, g);
638-
if (res != success)
639-
goto fail;
678+
res = ensure_backend(instance, e, wasi_nn_ctx);
679+
if (res != success) {
680+
NN_ERR_PRINTF("continue trying the next");
681+
continue;
682+
}
683+
684+
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res,
685+
wasi_nn_ctx->backend_ctx, nul_terminated_name,
686+
name_len, nul_terminated_config, config_len, g);
687+
if (res != success) {
688+
NN_ERR_PRINTF("continue trying the next");
689+
continue;
690+
}
691+
692+
break;
693+
}
640694

641-
res = success;
642695
fail:
643696
if (nul_terminated_name != NULL) {
644697
wasm_runtime_free(nul_terminated_name);

0 commit comments

Comments
 (0)