Przeglądaj źródła

wasi-nn: add minimum serialization on WASINNContext (#4387)

currently this is not necessary because the context (WASINNContext) is
local to the instance (wasm_module_instance_t).

i plan to make a context shared among instances in a cluster when
fixing https://github.com/bytecodealliance/wasm-micro-runtime/issues/4313.
this is a preparation for that direction.

an obvious alternative is to tweak the module instance context APIs
to allow declaring certain kinds of contexts as instance-local. but i feel,
in this particular case, it's more natural to make "wasi-nn handles"
shared among threads within a "process".

note that, spec-wise, how wasi-nn behaves wrt threads is not defined
at all because wasi officially doesn't have threads yet. i suppose, at
this point, that how wasi-nn interacts with wasi-threads is something
we need to define by ourselves, especially when we are using an outdated
wasi-nn version.

with this change, if a thread attempts to access a context while
another thread is using it, we simply make the operation fail with
the "busy" error. this is intended as the minimum serialization needed to
avoid problems like crashes/leaks/etc. it is not intended to allow
parallelism or anything similar.

no functional changes are intended at this point yet.

cf.
https://github.com/bytecodealliance/wasm-micro-runtime/issues/4313
https://github.com/bytecodealliance/wasm-micro-runtime/issues/2430
YAMAMOTO Takashi 8 miesięcy temu
rodzic
commit
ea408ab6c0

+ 116 - 33
core/iwasm/libraries/wasi-nn/src/wasi_nn.c

@@ -102,6 +102,8 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
     NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded);
     NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->backend);
 
+    bh_assert(!wasi_nn_ctx->busy);
+
     /* deinit() the backend */
     if (wasi_nn_ctx->is_backend_ctx_initialized) {
         wasi_nn_error res;
@@ -109,6 +111,7 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
                           wasi_nn_ctx->backend_ctx);
     }
 
+    os_mutex_destroy(&wasi_nn_ctx->lock);
     wasm_runtime_free(wasi_nn_ctx);
 }
 
@@ -154,6 +157,11 @@ wasi_nn_initialize_context()
     }
 
     memset(wasi_nn_ctx, 0, sizeof(WASINNContext));
+    if (os_mutex_init(&wasi_nn_ctx->lock)) {
+        NN_ERR_PRINTF("Error when initializing a lock for WASI-NN context");
+        wasm_runtime_free(wasi_nn_ctx);
+        return NULL;
+    }
     return wasi_nn_ctx;
 }
 
@@ -180,6 +188,35 @@ wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
     return wasi_nn_ctx;
 }
 
+static WASINNContext * /* Claim the instance's WASI-NN context for exclusive use; returns NULL on failure. */
+lock_ctx(wasm_module_inst_t instance)
+{
+    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
+    if (wasi_nn_ctx == NULL) { /* no context could be obtained for this instance */
+        return NULL;
+    }
+    os_mutex_lock(&wasi_nn_ctx->lock); /* the mutex is held only while testing/setting the busy flag */
+    if (wasi_nn_ctx->busy) { /* context is already claimed: fail immediately instead of waiting */
+        os_mutex_unlock(&wasi_nn_ctx->lock);
+        return NULL;
+    }
+    wasi_nn_ctx->busy = true; /* mark as claimed; cleared again by unlock_ctx() */
+    os_mutex_unlock(&wasi_nn_ctx->lock);
+    return wasi_nn_ctx; /* caller is expected to hand this back to unlock_ctx() when done */
+}
+
+static void /* Release a context previously claimed by lock_ctx(); a NULL argument is a no-op. */
+unlock_ctx(WASINNContext *wasi_nn_ctx)
+{
+    if (wasi_nn_ctx == NULL) { /* lets shared failure paths call this even when lock_ctx() returned NULL */
+        return;
+    }
+    os_mutex_lock(&wasi_nn_ctx->lock);
+    bh_assert(wasi_nn_ctx->busy); /* the context must currently be marked busy when it is released */
+    wasi_nn_ctx->busy = false;
+    os_mutex_unlock(&wasi_nn_ctx->lock);
+}
+
 void
 wasi_nn_destroy()
 {
@@ -405,7 +442,7 @@ detect_and_load_backend(graph_encoding backend_hint,
 
 static wasi_nn_error
 ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
-               WASINNContext **wasi_nn_ctx_ptr)
+               WASINNContext *wasi_nn_ctx)
 {
     wasi_nn_error res;
 
@@ -416,7 +453,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
         goto fail;
     }
 
-    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
     if (wasi_nn_ctx->is_backend_ctx_initialized) {
         if (wasi_nn_ctx->backend != loaded_backend) {
             res = unsupported_operation;
@@ -434,7 +470,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
 
         wasi_nn_ctx->is_backend_ctx_initialized = true;
     }
-    *wasi_nn_ctx_ptr = wasi_nn_ctx;
     return success;
 fail:
     return res;
@@ -462,17 +497,23 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
     if (!instance)
         return runtime_error;
 
+    WASINNContext *wasi_nn_ctx = lock_ctx(instance);
+    if (wasi_nn_ctx == NULL) {
+        res = busy;
+        goto fail;
+    }
+
     graph_builder_array builder_native = { 0 };
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
     if (success
         != (res = graph_builder_array_app_native(
                 instance, builder, builder_wasm_size, &builder_native)))
-        return res;
+        goto fail;
 #else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
     if (success
         != (res = graph_builder_array_app_native(instance, builder,
                                                  &builder_native)))
-        return res;
+        goto fail;
 #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 
     if (!wasm_runtime_validate_native_addr(instance, g,
@@ -482,8 +523,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
         goto fail;
     }
 
-    WASINNContext *wasi_nn_ctx;
-    res = ensure_backend(instance, encoding, &wasi_nn_ctx);
+    res = ensure_backend(instance, encoding, wasi_nn_ctx);
     if (res != success)
         goto fail;
 
@@ -498,6 +538,7 @@ fail:
     // XXX: Free intermediate structure pointers
     if (builder_native.buf)
         wasm_runtime_free(builder_native.buf);
+    unlock_ctx(wasi_nn_ctx);
 
     return res;
 }
@@ -531,18 +572,26 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
 
     NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);
 
-    WASINNContext *wasi_nn_ctx;
-    res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
+    WASINNContext *wasi_nn_ctx = lock_ctx(instance);
+    if (wasi_nn_ctx == NULL) {
+        res = busy;
+        goto fail;
+    }
+
+    res = ensure_backend(instance, autodetect, wasi_nn_ctx);
     if (res != success)
-        return res;
+        goto fail;
 
     call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
                       wasi_nn_ctx->backend_ctx, name, name_len, g);
     if (res != success)
-        return res;
+        goto fail;
 
     wasi_nn_ctx->is_model_loaded = true;
-    return success;
+    res = success;
+fail:
+    unlock_ctx(wasi_nn_ctx);
+    return res;
 }
 
 wasi_nn_error
@@ -580,19 +629,28 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
 
     NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config);
 
-    WASINNContext *wasi_nn_ctx;
-    res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
+    WASINNContext *wasi_nn_ctx = lock_ctx(instance);
+    if (wasi_nn_ctx == NULL) {
+        res = busy;
+        goto fail;
+    }
+
+    res = ensure_backend(instance, autodetect, wasi_nn_ctx);
     if (res != success)
-        return res;
+        goto fail;
+    ;
 
     call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res,
                       wasi_nn_ctx->backend_ctx, name, name_len, config,
                       config_len, g);
     if (res != success)
-        return res;
+        goto fail;
 
     wasi_nn_ctx->is_model_loaded = true;
-    return success;
+    res = success;
+fail:
+    unlock_ctx(wasi_nn_ctx);
+    return res;
 }
 
 wasi_nn_error
@@ -606,20 +664,27 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
         return runtime_error;
     }
 
-    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
-
     wasi_nn_error res;
+    WASINNContext *wasi_nn_ctx = lock_ctx(instance);
+    if (wasi_nn_ctx == NULL) {
+        res = busy;
+        goto fail;
+    }
+
     if (success != (res = is_model_initialized(wasi_nn_ctx)))
-        return res;
+        goto fail;
 
     if (!wasm_runtime_validate_native_addr(
             instance, ctx, (uint64)sizeof(graph_execution_context))) {
         NN_ERR_PRINTF("ctx is invalid");
-        return invalid_argument;
+        res = invalid_argument;
+        goto fail;
     }
 
     call_wasi_nn_func(wasi_nn_ctx->backend, init_execution_context, res,
                       wasi_nn_ctx->backend_ctx, g, ctx);
+fail:
+    unlock_ctx(wasi_nn_ctx);
     return res;
 }
 
@@ -634,17 +699,21 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
         return runtime_error;
     }
 
-    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
-
     wasi_nn_error res;
+    WASINNContext *wasi_nn_ctx = lock_ctx(instance);
+    if (wasi_nn_ctx == NULL) {
+        res = busy;
+        goto fail;
+    }
+
     if (success != (res = is_model_initialized(wasi_nn_ctx)))
-        return res;
+        goto fail;
 
     tensor input_tensor_native = { 0 };
     if (success
         != (res = tensor_app_native(instance, input_tensor,
                                     &input_tensor_native)))
-        return res;
+        goto fail;
 
     call_wasi_nn_func(wasi_nn_ctx->backend, set_input, res,
                       wasi_nn_ctx->backend_ctx, ctx, index,
@@ -652,7 +721,8 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
     // XXX: Free intermediate structure pointers
     if (input_tensor_native.dimensions)
         wasm_runtime_free(input_tensor_native.dimensions);
-
+fail:
+    unlock_ctx(wasi_nn_ctx);
     return res;
 }
 
@@ -666,14 +736,20 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
         return runtime_error;
     }
 
-    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
-
     wasi_nn_error res;
+    WASINNContext *wasi_nn_ctx = lock_ctx(instance);
+    if (wasi_nn_ctx == NULL) {
+        res = busy;
+        goto fail;
+    }
+
     if (success != (res = is_model_initialized(wasi_nn_ctx)))
-        return res;
+        goto fail;
 
     call_wasi_nn_func(wasi_nn_ctx->backend, compute, res,
                       wasi_nn_ctx->backend_ctx, ctx);
+fail:
+    unlock_ctx(wasi_nn_ctx);
     return res;
 }
 
@@ -696,16 +772,21 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
         return runtime_error;
     }
 
-    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
-
     wasi_nn_error res;
+    WASINNContext *wasi_nn_ctx = lock_ctx(instance);
+    if (wasi_nn_ctx == NULL) {
+        res = busy;
+        goto fail;
+    }
+
     if (success != (res = is_model_initialized(wasi_nn_ctx)))
-        return res;
+        goto fail;
 
     if (!wasm_runtime_validate_native_addr(instance, output_tensor_size,
                                            (uint64)sizeof(uint32_t))) {
         NN_ERR_PRINTF("output_tensor_size is invalid");
-        return invalid_argument;
+        res = invalid_argument;
+        goto fail;
     }
 
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
@@ -718,6 +799,8 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
                       wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
                       output_tensor_size);
 #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
+fail:
+    unlock_ctx(wasi_nn_ctx);
     return res;
 }
 

+ 4 - 0
core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h

@@ -9,7 +9,11 @@
 #include "wasi_nn_types.h"
 #include "wasm_export.h"
 
+#include "bh_platform.h"
+
 typedef struct {
+    korp_mutex lock;
+    bool busy;
     bool is_backend_ctx_initialized;
     bool is_model_loaded;
     graph_encoding backend;