|
@@ -16,12 +16,26 @@
|
|
|
#include "logger.h"
|
|
#include "logger.h"
|
|
|
|
|
|
|
|
#include "bh_platform.h"
|
|
#include "bh_platform.h"
|
|
|
|
|
+#include "wasi_nn_types.h"
|
|
|
#include "wasm_export.h"
|
|
#include "wasm_export.h"
|
|
|
|
|
|
|
|
#define HASHMAP_INITIAL_SIZE 20
|
|
#define HASHMAP_INITIAL_SIZE 20
|
|
|
|
|
|
|
|
/* Global variables */
|
|
/* Global variables */
|
|
|
-static api_function lookup[backend_amount] = { 0 };
|
|
|
|
|
|
|
+// if using `load_by_name`, there is no known `encoding` at the time of loading
|
|
|
|
|
+// so, just keep one `api_function` is enough
|
|
|
|
|
+static api_function lookup = { 0 };
|
|
|
|
|
+
|
|
|
|
|
+#define call_wasi_nn_func(wasi_error, func, ...) \
|
|
|
|
|
+ do { \
|
|
|
|
|
+ if (lookup.func) { \
|
|
|
|
|
+ wasi_error = lookup.func(__VA_ARGS__); \
|
|
|
|
|
+ } \
|
|
|
|
|
+ else { \
|
|
|
|
|
+ NN_ERR_PRINTF("Error: %s is not registered", #func); \
|
|
|
|
|
+ wasi_error = unsupported_operation; \
|
|
|
|
|
+ } \
|
|
|
|
|
+ } while (0)
|
|
|
|
|
|
|
|
static HashMap *hashmap;
|
|
static HashMap *hashmap;
|
|
|
|
|
|
|
@@ -73,16 +87,16 @@ wasi_nn_initialize_context()
|
|
|
return NULL;
|
|
return NULL;
|
|
|
}
|
|
}
|
|
|
wasi_nn_ctx->is_model_loaded = false;
|
|
wasi_nn_ctx->is_model_loaded = false;
|
|
|
|
|
+
|
|
|
/* only one backend can be registered */
|
|
/* only one backend can be registered */
|
|
|
- {
|
|
|
|
|
- unsigned i;
|
|
|
|
|
- for (i = 0; i < sizeof(lookup) / sizeof(lookup[0]); i++) {
|
|
|
|
|
- if (lookup[i].init) {
|
|
|
|
|
- lookup[i].init(&wasi_nn_ctx->backend_ctx);
|
|
|
|
|
- break;
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ wasi_nn_error res;
|
|
|
|
|
+ call_wasi_nn_func(res, init, &wasi_nn_ctx->backend_ctx);
|
|
|
|
|
+ if (res != success) {
|
|
|
|
|
+ NN_ERR_PRINTF("Error while initializing backend");
|
|
|
|
|
+ wasm_runtime_free(wasi_nn_ctx);
|
|
|
|
|
+ return NULL;
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
return wasi_nn_ctx;
|
|
return wasi_nn_ctx;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -90,6 +104,7 @@ static bool
|
|
|
wasi_nn_initialize()
|
|
wasi_nn_initialize()
|
|
|
{
|
|
{
|
|
|
NN_DBG_PRINTF("Initializing wasi-nn");
|
|
NN_DBG_PRINTF("Initializing wasi-nn");
|
|
|
|
|
+ // hashmap { instance: wasi_nn_ctx }
|
|
|
hashmap = bh_hash_map_create(HASHMAP_INITIAL_SIZE, true, hash_func,
|
|
hashmap = bh_hash_map_create(HASHMAP_INITIAL_SIZE, true, hash_func,
|
|
|
key_equal_func, key_destroy_func,
|
|
key_equal_func, key_destroy_func,
|
|
|
value_destroy_func);
|
|
value_destroy_func);
|
|
@@ -133,42 +148,26 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
|
|
|
NN_DBG_PRINTF("Freeing wasi-nn");
|
|
NN_DBG_PRINTF("Freeing wasi-nn");
|
|
|
NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded);
|
|
NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded);
|
|
|
NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->current_encoding);
|
|
NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->current_encoding);
|
|
|
|
|
+
|
|
|
/* only one backend can be registered */
|
|
/* only one backend can be registered */
|
|
|
- {
|
|
|
|
|
- unsigned i;
|
|
|
|
|
- for (i = 0; i < sizeof(lookup) / sizeof(lookup[0]); i++) {
|
|
|
|
|
- if (lookup[i].deinit) {
|
|
|
|
|
- lookup[i].deinit(wasi_nn_ctx->backend_ctx);
|
|
|
|
|
- break;
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ wasi_nn_error res;
|
|
|
|
|
+ call_wasi_nn_func(res, deinit, wasi_nn_ctx->backend_ctx);
|
|
|
|
|
+ if (res != success) {
|
|
|
|
|
+ NN_ERR_PRINTF("Error while destroyging backend");
|
|
|
}
|
|
}
|
|
|
- wasm_runtime_free(wasi_nn_ctx);
|
|
|
|
|
-}
|
|
|
|
|
|
|
|
|
|
-static void
|
|
|
|
|
-wasi_nn_ctx_destroy_helper(void *instance, void *wasi_nn_ctx, void *user_data)
|
|
|
|
|
-{
|
|
|
|
|
- wasi_nn_ctx_destroy((WASINNContext *)wasi_nn_ctx);
|
|
|
|
|
|
|
+ wasm_runtime_free(wasi_nn_ctx);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
void
|
|
void
|
|
|
wasi_nn_destroy()
|
|
wasi_nn_destroy()
|
|
|
{
|
|
{
|
|
|
- bh_hash_map_traverse(hashmap, wasi_nn_ctx_destroy_helper, NULL);
|
|
|
|
|
|
|
+ // destroy hashmap will destroy keys and values
|
|
|
bh_hash_map_destroy(hashmap);
|
|
bh_hash_map_destroy(hashmap);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
/* Utils */
|
|
/* Utils */
|
|
|
|
|
|
|
|
-static bool
|
|
|
|
|
-is_encoding_implemented(graph_encoding encoding)
|
|
|
|
|
-{
|
|
|
|
|
- return lookup[encoding].load && lookup[encoding].init_execution_context
|
|
|
|
|
- && lookup[encoding].set_input && lookup[encoding].compute
|
|
|
|
|
- && lookup[encoding].get_output;
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
static wasi_nn_error
|
|
static wasi_nn_error
|
|
|
is_model_initialized(WASINNContext *wasi_nn_ctx)
|
|
is_model_initialized(WASINNContext *wasi_nn_ctx)
|
|
|
{
|
|
{
|
|
@@ -195,13 +194,9 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
|
|
NN_DBG_PRINTF("Running wasi_nn_load [encoding=%d, target=%d]...", encoding,
|
|
NN_DBG_PRINTF("Running wasi_nn_load [encoding=%d, target=%d]...", encoding,
|
|
|
target);
|
|
target);
|
|
|
|
|
|
|
|
- if (!is_encoding_implemented(encoding)) {
|
|
|
|
|
- NN_ERR_PRINTF("Encoding not supported.");
|
|
|
|
|
- return invalid_encoding;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
|
|
- bh_assert(instance);
|
|
|
|
|
|
|
+ if (!instance)
|
|
|
|
|
+ return runtime_error;
|
|
|
|
|
|
|
|
wasi_nn_error res;
|
|
wasi_nn_error res;
|
|
|
graph_builder_array builder_native = { 0 };
|
|
graph_builder_array builder_native = { 0 };
|
|
@@ -225,10 +220,11 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
- res = lookup[encoding].load(wasi_nn_ctx->backend_ctx, &builder_native,
|
|
|
|
|
- encoding, target, g);
|
|
|
|
|
-
|
|
|
|
|
|
|
+ call_wasi_nn_func(res, load, wasi_nn_ctx->backend_ctx, &builder_native,
|
|
|
|
|
+ encoding, target, g);
|
|
|
NN_DBG_PRINTF("wasi_nn_load finished with status %d [graph=%d]", res, *g);
|
|
NN_DBG_PRINTF("wasi_nn_load finished with status %d [graph=%d]", res, *g);
|
|
|
|
|
+ if (res != success)
|
|
|
|
|
+ goto fail;
|
|
|
|
|
|
|
|
wasi_nn_ctx->current_encoding = encoding;
|
|
wasi_nn_ctx->current_encoding = encoding;
|
|
|
wasi_nn_ctx->is_model_loaded = true;
|
|
wasi_nn_ctx->is_model_loaded = true;
|
|
@@ -241,6 +237,39 @@ fail:
|
|
|
return res;
|
|
return res;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+wasi_nn_error
|
|
|
|
|
+wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
|
|
|
|
|
+ graph *g)
|
|
|
|
|
+{
|
|
|
|
|
+ NN_DBG_PRINTF("Running wasi_nn_load_by_name ...");
|
|
|
|
|
+
|
|
|
|
|
+ wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
|
|
|
|
+ if (!instance) {
|
|
|
|
|
+ return runtime_error;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (!wasm_runtime_validate_native_addr(instance, name, name_len)) {
|
|
|
|
|
+ return invalid_argument;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (!wasm_runtime_validate_native_addr(instance, g,
|
|
|
|
|
+ (uint64)sizeof(graph))) {
|
|
|
|
|
+ return invalid_argument;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
|
|
+ wasi_nn_error res;
|
|
|
|
|
+ call_wasi_nn_func(res, load_by_name, wasi_nn_ctx->backend_ctx, name,
|
|
|
|
|
+ name_len, g);
|
|
|
|
|
+ NN_DBG_PRINTF("wasi_nn_load_by_name finished with status %d", *g);
|
|
|
|
|
+ if (res != success)
|
|
|
|
|
+ return res;
|
|
|
|
|
+
|
|
|
|
|
+ wasi_nn_ctx->current_encoding = autodetect;
|
|
|
|
|
+ wasi_nn_ctx->is_model_loaded = true;
|
|
|
|
|
+ return success;
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
wasi_nn_error
|
|
wasi_nn_error
|
|
|
wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
|
|
wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
|
|
|
graph_execution_context *ctx)
|
|
graph_execution_context *ctx)
|
|
@@ -248,7 +277,10 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
|
|
|
NN_DBG_PRINTF("Running wasi_nn_init_execution_context [graph=%d]...", g);
|
|
NN_DBG_PRINTF("Running wasi_nn_init_execution_context [graph=%d]...", g);
|
|
|
|
|
|
|
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
|
|
- bh_assert(instance);
|
|
|
|
|
|
|
+ if (!instance) {
|
|
|
|
|
+ return runtime_error;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
|
|
|
|
|
wasi_nn_error res;
|
|
wasi_nn_error res;
|
|
@@ -261,9 +293,8 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
|
|
|
return invalid_argument;
|
|
return invalid_argument;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- res = lookup[wasi_nn_ctx->current_encoding].init_execution_context(
|
|
|
|
|
- wasi_nn_ctx->backend_ctx, g, ctx);
|
|
|
|
|
-
|
|
|
|
|
|
|
+ call_wasi_nn_func(res, init_execution_context, wasi_nn_ctx->backend_ctx, g,
|
|
|
|
|
+ ctx);
|
|
|
NN_DBG_PRINTF(
|
|
NN_DBG_PRINTF(
|
|
|
"wasi_nn_init_execution_context finished with status %d [ctx=%d]", res,
|
|
"wasi_nn_init_execution_context finished with status %d [ctx=%d]", res,
|
|
|
*ctx);
|
|
*ctx);
|
|
@@ -278,7 +309,10 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|
|
index);
|
|
index);
|
|
|
|
|
|
|
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
|
|
- bh_assert(instance);
|
|
|
|
|
|
|
+ if (!instance) {
|
|
|
|
|
+ return runtime_error;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
|
|
|
|
|
wasi_nn_error res;
|
|
wasi_nn_error res;
|
|
@@ -291,9 +325,8 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|
|
&input_tensor_native)))
|
|
&input_tensor_native)))
|
|
|
return res;
|
|
return res;
|
|
|
|
|
|
|
|
- res = lookup[wasi_nn_ctx->current_encoding].set_input(
|
|
|
|
|
- wasi_nn_ctx->backend_ctx, ctx, index, &input_tensor_native);
|
|
|
|
|
-
|
|
|
|
|
|
|
+ call_wasi_nn_func(res, set_input, wasi_nn_ctx->backend_ctx, ctx, index,
|
|
|
|
|
+ &input_tensor_native);
|
|
|
// XXX: Free intermediate structure pointers
|
|
// XXX: Free intermediate structure pointers
|
|
|
if (input_tensor_native.dimensions)
|
|
if (input_tensor_native.dimensions)
|
|
|
wasm_runtime_free(input_tensor_native.dimensions);
|
|
wasm_runtime_free(input_tensor_native.dimensions);
|
|
@@ -308,15 +341,17 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
|
|
|
NN_DBG_PRINTF("Running wasi_nn_compute [ctx=%d]...", ctx);
|
|
NN_DBG_PRINTF("Running wasi_nn_compute [ctx=%d]...", ctx);
|
|
|
|
|
|
|
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
|
|
- bh_assert(instance);
|
|
|
|
|
|
|
+ if (!instance) {
|
|
|
|
|
+ return runtime_error;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
|
|
|
|
|
wasi_nn_error res;
|
|
wasi_nn_error res;
|
|
|
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
|
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
|
|
return res;
|
|
return res;
|
|
|
|
|
|
|
|
- res = lookup[wasi_nn_ctx->current_encoding].compute(
|
|
|
|
|
- wasi_nn_ctx->backend_ctx, ctx);
|
|
|
|
|
|
|
+ call_wasi_nn_func(res, compute, wasi_nn_ctx->backend_ctx, ctx);
|
|
|
NN_DBG_PRINTF("wasi_nn_compute finished with status %d", res);
|
|
NN_DBG_PRINTF("wasi_nn_compute finished with status %d", res);
|
|
|
return res;
|
|
return res;
|
|
|
}
|
|
}
|
|
@@ -337,7 +372,10 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|
|
index);
|
|
index);
|
|
|
|
|
|
|
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
|
|
- bh_assert(instance);
|
|
|
|
|
|
|
+ if (!instance) {
|
|
|
|
|
+ return runtime_error;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
|
|
|
|
|
wasi_nn_error res;
|
|
wasi_nn_error res;
|
|
@@ -351,14 +389,12 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
|
|
- res = lookup[wasi_nn_ctx->current_encoding].get_output(
|
|
|
|
|
- wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
|
|
|
|
|
- &output_tensor_len);
|
|
|
|
|
|
|
+ call_wasi_nn_func(res, get_output, wasi_nn_ctx->backend_ctx, ctx, index,
|
|
|
|
|
+ output_tensor, &output_tensor_len);
|
|
|
*output_tensor_size = output_tensor_len;
|
|
*output_tensor_size = output_tensor_len;
|
|
|
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
|
|
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
|
|
|
- res = lookup[wasi_nn_ctx->current_encoding].get_output(
|
|
|
|
|
- wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
|
|
|
|
|
- output_tensor_size);
|
|
|
|
|
|
|
+ call_wasi_nn_func(res, get_output, wasi_nn_ctx->backend_ctx, ctx, index,
|
|
|
|
|
+ output_tensor, output_tensor_size);
|
|
|
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
|
|
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
|
|
|
NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]",
|
|
NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]",
|
|
|
res, *output_tensor_size);
|
|
res, *output_tensor_size);
|
|
@@ -375,6 +411,7 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|
|
static NativeSymbol native_symbols_wasi_nn[] = {
|
|
static NativeSymbol native_symbols_wasi_nn[] = {
|
|
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
|
|
REG_NATIVE_FUNC(load, "(*iii*)i"),
|
|
REG_NATIVE_FUNC(load, "(*iii*)i"),
|
|
|
|
|
+ REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
|
|
|
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
|
|
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
|
|
|
REG_NATIVE_FUNC(set_input, "(ii*)i"),
|
|
REG_NATIVE_FUNC(set_input, "(ii*)i"),
|
|
|
REG_NATIVE_FUNC(compute, "(i)i"),
|
|
REG_NATIVE_FUNC(compute, "(i)i"),
|
|
@@ -429,15 +466,9 @@ deinit_native_lib()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
__attribute__((used)) bool
|
|
__attribute__((used)) bool
|
|
|
-wasi_nn_register_backend(graph_encoding backend_code, api_function apis)
|
|
|
|
|
|
|
+wasi_nn_register_backend(api_function apis)
|
|
|
{
|
|
{
|
|
|
NN_DBG_PRINTF("--|> wasi_nn_register_backend");
|
|
NN_DBG_PRINTF("--|> wasi_nn_register_backend");
|
|
|
-
|
|
|
|
|
- if (backend_code >= sizeof(lookup) / sizeof(lookup[0])) {
|
|
|
|
|
- NN_ERR_PRINTF("Invalid backend code");
|
|
|
|
|
- return false;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- lookup[backend_code] = apis;
|
|
|
|
|
|
|
+ lookup = apis;
|
|
|
return true;
|
|
return true;
|
|
|
}
|
|
}
|