|
|
@@ -102,6 +102,8 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
|
|
|
NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded);
|
|
|
NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->backend);
|
|
|
|
|
|
+ bh_assert(!wasi_nn_ctx->busy);
|
|
|
+
|
|
|
/* deinit() the backend */
|
|
|
if (wasi_nn_ctx->is_backend_ctx_initialized) {
|
|
|
wasi_nn_error res;
|
|
|
@@ -109,6 +111,7 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
|
|
|
wasi_nn_ctx->backend_ctx);
|
|
|
}
|
|
|
|
|
|
+ os_mutex_destroy(&wasi_nn_ctx->lock);
|
|
|
wasm_runtime_free(wasi_nn_ctx);
|
|
|
}
|
|
|
|
|
|
@@ -154,6 +157,11 @@ wasi_nn_initialize_context()
|
|
|
}
|
|
|
|
|
|
memset(wasi_nn_ctx, 0, sizeof(WASINNContext));
|
|
|
+ if (os_mutex_init(&wasi_nn_ctx->lock)) {
|
|
|
+ NN_ERR_PRINTF("Error when initializing a lock for WASI-NN context");
|
|
|
+ wasm_runtime_free(wasi_nn_ctx);
|
|
|
+ return NULL;
|
|
|
+ }
|
|
|
return wasi_nn_ctx;
|
|
|
}
|
|
|
|
|
|
@@ -180,6 +188,35 @@ wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
|
|
|
return wasi_nn_ctx;
|
|
|
}
|
|
|
|
|
|
+static WASINNContext *
|
|
|
+lock_ctx(wasm_module_inst_t instance)
|
|
|
+{
|
|
|
+ WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
+ if (wasi_nn_ctx == NULL) {
|
|
|
+ return NULL;
|
|
|
+ }
|
|
|
+ os_mutex_lock(&wasi_nn_ctx->lock);
|
|
|
+ if (wasi_nn_ctx->busy) {
|
|
|
+ os_mutex_unlock(&wasi_nn_ctx->lock);
|
|
|
+ return NULL;
|
|
|
+ }
|
|
|
+ wasi_nn_ctx->busy = true;
|
|
|
+ os_mutex_unlock(&wasi_nn_ctx->lock);
|
|
|
+ return wasi_nn_ctx;
|
|
|
+}
|
|
|
+
|
|
|
+static void
|
|
|
+unlock_ctx(WASINNContext *wasi_nn_ctx)
|
|
|
+{
|
|
|
+ if (wasi_nn_ctx == NULL) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ os_mutex_lock(&wasi_nn_ctx->lock);
|
|
|
+ bh_assert(wasi_nn_ctx->busy);
|
|
|
+ wasi_nn_ctx->busy = false;
|
|
|
+ os_mutex_unlock(&wasi_nn_ctx->lock);
|
|
|
+}
|
|
|
+
|
|
|
void
|
|
|
wasi_nn_destroy()
|
|
|
{
|
|
|
@@ -405,7 +442,7 @@ detect_and_load_backend(graph_encoding backend_hint,
|
|
|
|
|
|
static wasi_nn_error
|
|
|
ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
|
|
|
- WASINNContext **wasi_nn_ctx_ptr)
|
|
|
+ WASINNContext *wasi_nn_ctx)
|
|
|
{
|
|
|
wasi_nn_error res;
|
|
|
|
|
|
@@ -416,7 +453,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
|
|
|
goto fail;
|
|
|
}
|
|
|
|
|
|
- WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
if (wasi_nn_ctx->is_backend_ctx_initialized) {
|
|
|
if (wasi_nn_ctx->backend != loaded_backend) {
|
|
|
res = unsupported_operation;
|
|
|
@@ -434,7 +470,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
|
|
|
|
|
|
wasi_nn_ctx->is_backend_ctx_initialized = true;
|
|
|
}
|
|
|
- *wasi_nn_ctx_ptr = wasi_nn_ctx;
|
|
|
return success;
|
|
|
fail:
|
|
|
return res;
|
|
|
@@ -462,17 +497,23 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
|
|
if (!instance)
|
|
|
return runtime_error;
|
|
|
|
|
|
+ WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
|
|
+ if (wasi_nn_ctx == NULL) {
|
|
|
+ res = busy;
|
|
|
+ goto fail;
|
|
|
+ }
|
|
|
+
|
|
|
graph_builder_array builder_native = { 0 };
|
|
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
|
|
if (success
|
|
|
!= (res = graph_builder_array_app_native(
|
|
|
instance, builder, builder_wasm_size, &builder_native)))
|
|
|
- return res;
|
|
|
+ goto fail;
|
|
|
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
|
|
|
if (success
|
|
|
!= (res = graph_builder_array_app_native(instance, builder,
|
|
|
&builder_native)))
|
|
|
- return res;
|
|
|
+ goto fail;
|
|
|
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
|
|
|
|
|
|
if (!wasm_runtime_validate_native_addr(instance, g,
|
|
|
@@ -482,8 +523,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
|
|
goto fail;
|
|
|
}
|
|
|
|
|
|
- WASINNContext *wasi_nn_ctx;
|
|
|
- res = ensure_backend(instance, encoding, &wasi_nn_ctx);
|
|
|
+ res = ensure_backend(instance, encoding, wasi_nn_ctx);
|
|
|
if (res != success)
|
|
|
goto fail;
|
|
|
|
|
|
@@ -498,6 +538,7 @@ fail:
|
|
|
// XXX: Free intermediate structure pointers
|
|
|
if (builder_native.buf)
|
|
|
wasm_runtime_free(builder_native.buf);
|
|
|
+ unlock_ctx(wasi_nn_ctx);
|
|
|
|
|
|
return res;
|
|
|
}
|
|
|
@@ -531,18 +572,26 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
|
|
|
|
|
|
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);
|
|
|
|
|
|
- WASINNContext *wasi_nn_ctx;
|
|
|
- res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
|
|
|
+ WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
|
|
+ if (wasi_nn_ctx == NULL) {
|
|
|
+ res = busy;
|
|
|
+ goto fail;
|
|
|
+ }
|
|
|
+
|
|
|
+ res = ensure_backend(instance, autodetect, wasi_nn_ctx);
|
|
|
if (res != success)
|
|
|
- return res;
|
|
|
+ goto fail;
|
|
|
|
|
|
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
|
|
|
wasi_nn_ctx->backend_ctx, name, name_len, g);
|
|
|
if (res != success)
|
|
|
- return res;
|
|
|
+ goto fail;
|
|
|
|
|
|
wasi_nn_ctx->is_model_loaded = true;
|
|
|
- return success;
|
|
|
+ res = success;
|
|
|
+fail:
|
|
|
+ unlock_ctx(wasi_nn_ctx);
|
|
|
+ return res;
|
|
|
}
|
|
|
|
|
|
wasi_nn_error
|
|
|
@@ -580,19 +629,28 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
|
|
|
|
|
|
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config);
|
|
|
|
|
|
- WASINNContext *wasi_nn_ctx;
|
|
|
- res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
|
|
|
+ WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
|
|
+ if (wasi_nn_ctx == NULL) {
|
|
|
+ res = busy;
|
|
|
+ goto fail;
|
|
|
+ }
|
|
|
+
|
|
|
+ res = ensure_backend(instance, autodetect, wasi_nn_ctx);
|
|
|
if (res != success)
|
|
|
- return res;
|
|
|
+ goto fail;
|
|
|
+ ;
|
|
|
|
|
|
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res,
|
|
|
wasi_nn_ctx->backend_ctx, name, name_len, config,
|
|
|
config_len, g);
|
|
|
if (res != success)
|
|
|
- return res;
|
|
|
+ goto fail;
|
|
|
|
|
|
wasi_nn_ctx->is_model_loaded = true;
|
|
|
- return success;
|
|
|
+ res = success;
|
|
|
+fail:
|
|
|
+ unlock_ctx(wasi_nn_ctx);
|
|
|
+ return res;
|
|
|
}
|
|
|
|
|
|
wasi_nn_error
|
|
|
@@ -606,20 +664,27 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
|
|
|
return runtime_error;
|
|
|
}
|
|
|
|
|
|
- WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
-
|
|
|
wasi_nn_error res;
|
|
|
+ WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
|
|
+ if (wasi_nn_ctx == NULL) {
|
|
|
+ res = busy;
|
|
|
+ goto fail;
|
|
|
+ }
|
|
|
+
|
|
|
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
|
|
- return res;
|
|
|
+ goto fail;
|
|
|
|
|
|
if (!wasm_runtime_validate_native_addr(
|
|
|
instance, ctx, (uint64)sizeof(graph_execution_context))) {
|
|
|
NN_ERR_PRINTF("ctx is invalid");
|
|
|
- return invalid_argument;
|
|
|
+ res = invalid_argument;
|
|
|
+ goto fail;
|
|
|
}
|
|
|
|
|
|
call_wasi_nn_func(wasi_nn_ctx->backend, init_execution_context, res,
|
|
|
wasi_nn_ctx->backend_ctx, g, ctx);
|
|
|
+fail:
|
|
|
+ unlock_ctx(wasi_nn_ctx);
|
|
|
return res;
|
|
|
}
|
|
|
|
|
|
@@ -634,17 +699,21 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|
|
return runtime_error;
|
|
|
}
|
|
|
|
|
|
- WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
-
|
|
|
wasi_nn_error res;
|
|
|
+ WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
|
|
+ if (wasi_nn_ctx == NULL) {
|
|
|
+ res = busy;
|
|
|
+ goto fail;
|
|
|
+ }
|
|
|
+
|
|
|
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
|
|
- return res;
|
|
|
+ goto fail;
|
|
|
|
|
|
tensor input_tensor_native = { 0 };
|
|
|
if (success
|
|
|
!= (res = tensor_app_native(instance, input_tensor,
|
|
|
&input_tensor_native)))
|
|
|
- return res;
|
|
|
+ goto fail;
|
|
|
|
|
|
call_wasi_nn_func(wasi_nn_ctx->backend, set_input, res,
|
|
|
wasi_nn_ctx->backend_ctx, ctx, index,
|
|
|
@@ -652,7 +721,8 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|
|
// XXX: Free intermediate structure pointers
|
|
|
if (input_tensor_native.dimensions)
|
|
|
wasm_runtime_free(input_tensor_native.dimensions);
|
|
|
-
|
|
|
+fail:
|
|
|
+ unlock_ctx(wasi_nn_ctx);
|
|
|
return res;
|
|
|
}
|
|
|
|
|
|
@@ -666,14 +736,20 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
|
|
|
return runtime_error;
|
|
|
}
|
|
|
|
|
|
- WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
-
|
|
|
wasi_nn_error res;
|
|
|
+ WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
|
|
+ if (wasi_nn_ctx == NULL) {
|
|
|
+ res = busy;
|
|
|
+ goto fail;
|
|
|
+ }
|
|
|
+
|
|
|
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
|
|
- return res;
|
|
|
+ goto fail;
|
|
|
|
|
|
call_wasi_nn_func(wasi_nn_ctx->backend, compute, res,
|
|
|
wasi_nn_ctx->backend_ctx, ctx);
|
|
|
+fail:
|
|
|
+ unlock_ctx(wasi_nn_ctx);
|
|
|
return res;
|
|
|
}
|
|
|
|
|
|
@@ -696,16 +772,21 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|
|
return runtime_error;
|
|
|
}
|
|
|
|
|
|
- WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
-
|
|
|
wasi_nn_error res;
|
|
|
+ WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
|
|
+ if (wasi_nn_ctx == NULL) {
|
|
|
+ res = busy;
|
|
|
+ goto fail;
|
|
|
+ }
|
|
|
+
|
|
|
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
|
|
- return res;
|
|
|
+ goto fail;
|
|
|
|
|
|
if (!wasm_runtime_validate_native_addr(instance, output_tensor_size,
|
|
|
(uint64)sizeof(uint32_t))) {
|
|
|
NN_ERR_PRINTF("output_tensor_size is invalid");
|
|
|
- return invalid_argument;
|
|
|
+ res = invalid_argument;
|
|
|
+ goto fail;
|
|
|
}
|
|
|
|
|
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
|
|
@@ -718,6 +799,8 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|
|
wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
|
|
|
output_tensor_size);
|
|
|
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
|
|
|
+fail:
|
|
|
+ unlock_ctx(wasi_nn_ctx);
|
|
|
return res;
|
|
|
}
|
|
|
|