Ver Fonte

wasi-nn: fix context lifetime issues (#4396)

* wasi-nn: fix context lifetime issues

use the module instance context api instead of trying to roll
our own with a hashmap. this fixes context lifetime problems mentioned in
https://github.com/bytecodealliance/wasm-micro-runtime/issues/4313.

namely,

* wasi-nn resources will be freed earlier now. before this change,
  they used to be kept until the runtime shutdown. (wasm_runtime_destroy)
  after this change, they will be freed together with the associated
  instances.

* wasm_module_inst_t pointer uniqueness assumption (which is wrong
  after wasm_runtime_deinstantiate) was lifted.

as a side effect, this change also makes a context shared among threads
within a cluster. note that this is a user-visible api/abi breaking change.
before this change, wasi-nn "handles" like wasi_ephemeral_nn_graph were
thread-local. after this change, they are shared among threads within
a cluster, similarly to wasi file descriptors. spec-wise, either behavior
should be ok simply because wasi officially doesn't have threads yet.
althogh i feel the latter semantics is more intuitive, if your application
depends on the thread-local behavior, this change breaks your application.

tested with wamr-wasi-extensions/samples/nn-cli, modified to
call each wasi-nn operations on different threads. (if you are
interested, you can find the modification at
https://github.com/yamt/wasm-micro-runtime/tree/yamt-nn-wip-20250619.)

cf.
https://github.com/bytecodealliance/wasm-micro-runtime/issues/4313
https://github.com/bytecodealliance/wasm-micro-runtime/issues/2430

* runtime_lib.cmake: enable WAMR_BUILD_MODULE_INST_CONTEXT for wasi-nn

as we do for wasi (WAMR_BUILD_LIBC_WASI)
YAMAMOTO Takashi há 8 meses atrás
pai
commit
70c39bae77
2 ficheiros alterados com 22 adições e 57 exclusões
  1. 1 0
      build-scripts/runtime_lib.cmake
  2. 21 57
      core/iwasm/libraries/wasi-nn/src/wasi_nn.c

+ 1 - 0
build-scripts/runtime_lib.cmake

@@ -106,6 +106,7 @@ endif ()
 
 if (WAMR_BUILD_WASI_NN EQUAL 1)
     include (${IWASM_DIR}/libraries/wasi-nn/cmake/wasi_nn.cmake)
+    set (WAMR_BUILD_MODULE_INST_CONTEXT 1)
 endif ()
 
 if (WAMR_BUILD_LIB_PTHREAD EQUAL 1)

+ 21 - 57
core/iwasm/libraries/wasi-nn/src/wasi_nn.c

@@ -55,49 +55,15 @@ struct backends_api_functions {
             NN_ERR_PRINTF("Error %s() -> %d", #func, wasi_error);          \
     } while (0)
 
-/* HashMap utils */
-static HashMap *hashmap;
-
-static uint32
-hash_func(const void *key)
-{
-    // fnv1a_hash
-    const uint32 FNV_PRIME = 16777619;
-    const uint32 FNV_OFFSET_BASIS = 2166136261U;
-
-    uint32 hash = FNV_OFFSET_BASIS;
-    const unsigned char *bytes = (const unsigned char *)key;
-
-    for (size_t i = 0; i < sizeof(uintptr_t); ++i) {
-        hash ^= bytes[i];
-        hash *= FNV_PRIME;
-    }
-
-    return hash;
-}
-
-static bool
-key_equal_func(void *key1, void *key2)
-{
-    return key1 == key2;
-}
-
-static void
-key_destroy_func(void *key1)
-{
-    /* key type is wasm_module_inst_t*. do nothing */
-}
+static void *wasi_nn_key;
 
 static void
 wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
 {
-    NN_DBG_PRINTF("[WASI NN] DEINIT...");
-
     if (wasi_nn_ctx == NULL) {
-        NN_ERR_PRINTF(
-            "Error when deallocating memory. WASI-NN context is NULL");
         return;
     }
+    NN_DBG_PRINTF("[WASI NN] DEINIT...");
     NN_DBG_PRINTF("Freeing wasi-nn");
     NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded);
     NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->backend);
@@ -116,9 +82,9 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
 }
 
 static void
-value_destroy_func(void *value)
+dtor(wasm_module_inst_t inst, void *ctx)
 {
-    wasi_nn_ctx_destroy((WASINNContext *)value);
+    wasi_nn_ctx_destroy(ctx);
 }
 
 bool
@@ -131,12 +97,9 @@ wasi_nn_initialize()
         return false;
     }
 
-    // hashmap { instance: wasi_nn_ctx }
-    hashmap = bh_hash_map_create(HASHMAP_INITIAL_SIZE, true, hash_func,
-                                 key_equal_func, key_destroy_func,
-                                 value_destroy_func);
-    if (hashmap == NULL) {
-        NN_ERR_PRINTF("Error while initializing hashmap");
+    wasi_nn_key = wasm_runtime_create_context_key(dtor);
+    if (wasi_nn_key == NULL) {
+        NN_ERR_PRINTF("Failed to create context key");
         os_mutex_destroy(&wasi_nn_lock);
         return false;
     }
@@ -170,21 +133,23 @@ static WASINNContext *
 wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
 {
     WASINNContext *wasi_nn_ctx =
-        (WASINNContext *)bh_hash_map_find(hashmap, (void *)instance);
+        wasm_runtime_get_context(instance, wasi_nn_key);
     if (wasi_nn_ctx == NULL) {
-        wasi_nn_ctx = wasi_nn_initialize_context();
-        if (wasi_nn_ctx == NULL)
-            return NULL;
-
-        bool ok =
-            bh_hash_map_insert(hashmap, (void *)instance, (void *)wasi_nn_ctx);
-        if (!ok) {
-            NN_ERR_PRINTF("Error while storing context");
-            wasi_nn_ctx_destroy(wasi_nn_ctx);
+        WASINNContext *newctx = wasi_nn_initialize_context();
+        if (newctx == NULL)
             return NULL;
+        os_mutex_lock(&wasi_nn_lock);
+        wasi_nn_ctx = wasm_runtime_get_context(instance, wasi_nn_key);
+        if (wasi_nn_ctx == NULL) {
+            wasm_runtime_set_context_spread(instance, wasi_nn_key, newctx);
+            wasi_nn_ctx = newctx;
+            newctx = NULL;
+        }
+        os_mutex_unlock(&wasi_nn_lock);
+        if (newctx != NULL) {
+            wasi_nn_ctx_destroy(newctx);
         }
     }
-
     return wasi_nn_ctx;
 }
 
@@ -220,8 +185,7 @@ unlock_ctx(WASINNContext *wasi_nn_ctx)
 void
 wasi_nn_destroy()
 {
-    // destroy hashmap will destroy keys and values
-    bh_hash_map_destroy(hashmap);
+    wasm_runtime_destroy_context_key(wasi_nn_key);
 
     // close backends' libraries and registered functions
     for (unsigned i = 0; i < sizeof(lookup) / sizeof(lookup[0]); i++) {