Browse Source

wasi-nn: fix backend leak on multiple loads (#4366)

cf. https://github.com/bytecodealliance/wasm-micro-runtime/issues/4340
YAMAMOTO Takashi 9 tháng trước cách đây
mục cha
commit
0d001c4c38

+ 49 - 40
core/iwasm/libraries/wasi-nn/src/wasi_nn.c

@@ -397,6 +397,43 @@ detect_and_load_backend(graph_encoding backend_hint,
     return ret;
 }
 
+static wasi_nn_error
+ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
+               WASINNContext **wasi_nn_ctx_ptr)
+{
+    wasi_nn_error res;
+
+    graph_encoding loaded_backend = autodetect;
+    if (!detect_and_load_backend(encoding, &loaded_backend)) {
+        res = invalid_encoding;
+        NN_ERR_PRINTF("load backend failed");
+        goto fail;
+    }
+
+    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
+    if (wasi_nn_ctx->is_backend_ctx_initialized) {
+        if (wasi_nn_ctx->backend != loaded_backend) {
+            res = unsupported_operation;
+            goto fail;
+        }
+    }
+    else {
+        wasi_nn_ctx->backend = loaded_backend;
+
+        /* init() the backend */
+        call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
+                          &wasi_nn_ctx->backend_ctx);
+        if (res != success)
+            goto fail;
+
+        wasi_nn_ctx->is_backend_ctx_initialized = true;
+    }
+    *wasi_nn_ctx_ptr = wasi_nn_ctx;
+    return success;
+fail:
+    return res;
+}
+
 /* WASI-NN implementation */
 
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
@@ -410,6 +447,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
              graph_encoding encoding, execution_target target, graph *g)
 #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 {
+    wasi_nn_error res;
+
     NN_DBG_PRINTF("[WASI NN] LOAD [encoding=%d, target=%d]...", encoding,
                   target);
 
@@ -417,7 +456,6 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
     if (!instance)
         return runtime_error;
 
-    wasi_nn_error res;
     graph_builder_array builder_native = { 0 };
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
     if (success
@@ -438,19 +476,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
         goto fail;
     }
 
-    graph_encoding loaded_backend = autodetect;
-    if (!detect_and_load_backend(encoding, &loaded_backend)) {
-        res = invalid_encoding;
-        NN_ERR_PRINTF("load backend failed");
-        goto fail;
-    }
-
-    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
-    wasi_nn_ctx->backend = loaded_backend;
-
-    /* init() the backend */
-    call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
-                      &wasi_nn_ctx->backend_ctx);
+    WASINNContext *wasi_nn_ctx;
+    res = ensure_backend(instance, encoding, &wasi_nn_ctx);
     if (res != success)
         goto fail;
 
@@ -473,6 +500,8 @@ wasi_nn_error
 wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
                      graph *g)
 {
+    wasi_nn_error res;
+
     wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
     if (!instance) {
         return runtime_error;
@@ -496,19 +525,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
 
     NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);
 
-    graph_encoding loaded_backend = autodetect;
-    if (!detect_and_load_backend(autodetect, &loaded_backend)) {
-        NN_ERR_PRINTF("load backend failed");
-        return invalid_encoding;
-    }
-
-    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
-    wasi_nn_ctx->backend = loaded_backend;
-
-    wasi_nn_error res;
-    /* init() the backend */
-    call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
-                      &wasi_nn_ctx->backend_ctx);
+    WASINNContext *wasi_nn_ctx;
+    res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
     if (res != success)
         return res;
 
@@ -526,6 +544,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
                                  int32_t name_len, char *config,
                                  int32_t config_len, graph *g)
 {
+    wasi_nn_error res;
+
     wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
     if (!instance) {
         return runtime_error;
@@ -554,19 +574,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
 
     NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config);
 
-    graph_encoding loaded_backend = autodetect;
-    if (!detect_and_load_backend(autodetect, &loaded_backend)) {
-        NN_ERR_PRINTF("load backend failed");
-        return invalid_encoding;
-    }
-
-    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
-    wasi_nn_ctx->backend = loaded_backend;
-
-    wasi_nn_error res;
-    /* init() the backend */
-    call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
-                      &wasi_nn_ctx->backend_ctx);
+    WASINNContext *wasi_nn_ctx;
+    res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
     if (res != success)
         return res;
 

+ 1 - 0
core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h

@@ -10,6 +10,7 @@
 #include "wasm_export.h"
 
 typedef struct {
+    bool is_backend_ctx_initialized;
     bool is_model_loaded;
     graph_encoding backend;
     void *backend_ctx;