|
|
@@ -397,6 +397,43 @@ detect_and_load_backend(graph_encoding backend_hint,
|
|
|
return ret;
|
|
|
}
|
|
|
|
|
|
+static wasi_nn_error
|
|
|
+ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
|
|
|
+ WASINNContext **wasi_nn_ctx_ptr)
|
|
|
+{
|
|
|
+ wasi_nn_error res;
|
|
|
+
|
|
|
+ graph_encoding loaded_backend = autodetect;
|
|
|
+ if (!detect_and_load_backend(encoding, &loaded_backend)) {
|
|
|
+ res = invalid_encoding;
|
|
|
+ NN_ERR_PRINTF("load backend failed");
|
|
|
+ goto fail;
|
|
|
+ }
|
|
|
+
|
|
|
+ WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
+ if (wasi_nn_ctx->is_backend_ctx_initialized) {
|
|
|
+ if (wasi_nn_ctx->backend != loaded_backend) {
|
|
|
+ res = unsupported_operation;
|
|
|
+ goto fail;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ wasi_nn_ctx->backend = loaded_backend;
|
|
|
+
|
|
|
+ /* init() the backend */
|
|
|
+ call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
|
|
|
+ &wasi_nn_ctx->backend_ctx);
|
|
|
+ if (res != success)
|
|
|
+ goto fail;
|
|
|
+
|
|
|
+ wasi_nn_ctx->is_backend_ctx_initialized = true;
|
|
|
+ }
|
|
|
+ *wasi_nn_ctx_ptr = wasi_nn_ctx;
|
|
|
+ return success;
|
|
|
+fail:
|
|
|
+ return res;
|
|
|
+}
|
|
|
+
|
|
|
/* WASI-NN implementation */
|
|
|
|
|
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
|
|
@@ -410,6 +447,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
|
|
graph_encoding encoding, execution_target target, graph *g)
|
|
|
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
|
|
|
{
|
|
|
+ wasi_nn_error res;
|
|
|
+
|
|
|
NN_DBG_PRINTF("[WASI NN] LOAD [encoding=%d, target=%d]...", encoding,
|
|
|
target);
|
|
|
|
|
|
@@ -417,7 +456,6 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
|
|
if (!instance)
|
|
|
return runtime_error;
|
|
|
|
|
|
- wasi_nn_error res;
|
|
|
graph_builder_array builder_native = { 0 };
|
|
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
|
|
if (success
|
|
|
@@ -438,19 +476,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
|
|
goto fail;
|
|
|
}
|
|
|
|
|
|
- graph_encoding loaded_backend = autodetect;
|
|
|
- if (!detect_and_load_backend(encoding, &loaded_backend)) {
|
|
|
- res = invalid_encoding;
|
|
|
- NN_ERR_PRINTF("load backend failed");
|
|
|
- goto fail;
|
|
|
- }
|
|
|
-
|
|
|
- WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
- wasi_nn_ctx->backend = loaded_backend;
|
|
|
-
|
|
|
- /* init() the backend */
|
|
|
- call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
|
|
|
- &wasi_nn_ctx->backend_ctx);
|
|
|
+ WASINNContext *wasi_nn_ctx;
|
|
|
+ res = ensure_backend(instance, encoding, &wasi_nn_ctx);
|
|
|
if (res != success)
|
|
|
goto fail;
|
|
|
|
|
|
@@ -473,6 +500,8 @@ wasi_nn_error
|
|
|
wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
|
|
|
graph *g)
|
|
|
{
|
|
|
+ wasi_nn_error res;
|
|
|
+
|
|
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
|
|
if (!instance) {
|
|
|
return runtime_error;
|
|
|
@@ -496,19 +525,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
|
|
|
|
|
|
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);
|
|
|
|
|
|
- graph_encoding loaded_backend = autodetect;
|
|
|
- if (!detect_and_load_backend(autodetect, &loaded_backend)) {
|
|
|
- NN_ERR_PRINTF("load backend failed");
|
|
|
- return invalid_encoding;
|
|
|
- }
|
|
|
-
|
|
|
- WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
- wasi_nn_ctx->backend = loaded_backend;
|
|
|
-
|
|
|
- wasi_nn_error res;
|
|
|
- /* init() the backend */
|
|
|
- call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
|
|
|
- &wasi_nn_ctx->backend_ctx);
|
|
|
+ WASINNContext *wasi_nn_ctx;
|
|
|
+ res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
|
|
|
if (res != success)
|
|
|
return res;
|
|
|
|
|
|
@@ -526,6 +544,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
|
|
|
int32_t name_len, char *config,
|
|
|
int32_t config_len, graph *g)
|
|
|
{
|
|
|
+ wasi_nn_error res;
|
|
|
+
|
|
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
|
|
if (!instance) {
|
|
|
return runtime_error;
|
|
|
@@ -554,19 +574,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
|
|
|
|
|
|
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config);
|
|
|
|
|
|
- graph_encoding loaded_backend = autodetect;
|
|
|
- if (!detect_and_load_backend(autodetect, &loaded_backend)) {
|
|
|
- NN_ERR_PRINTF("load backend failed");
|
|
|
- return invalid_encoding;
|
|
|
- }
|
|
|
-
|
|
|
- WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
- wasi_nn_ctx->backend = loaded_backend;
|
|
|
-
|
|
|
- wasi_nn_error res;
|
|
|
- /* init() the backend */
|
|
|
- call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
|
|
|
- &wasi_nn_ctx->backend_ctx);
|
|
|
+ WASINNContext *wasi_nn_ctx;
|
|
|
+ res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
|
|
|
if (res != success)
|
|
|
return res;
|
|
|
|