|
|
@@ -13,7 +13,6 @@
|
|
|
|
|
|
#include "wasi_nn_private.h"
|
|
|
#include "wasi_nn_app_native.h"
|
|
|
-#include "wasi_nn_tensorflowlite.hpp"
|
|
|
#include "logger.h"
|
|
|
|
|
|
#include "bh_platform.h"
|
|
|
@@ -21,45 +20,14 @@
|
|
|
|
|
|
#define HASHMAP_INITIAL_SIZE 20
|
|
|
|
|
|
-/* Definition of 'wasi_nn.h' structs in WASM app format (using offset) */
|
|
|
-
|
|
|
-typedef wasi_nn_error (*LOAD)(void *, graph_builder_array *, graph_encoding,
|
|
|
- execution_target, graph *);
|
|
|
-typedef wasi_nn_error (*INIT_EXECUTION_CONTEXT)(void *, graph,
|
|
|
- graph_execution_context *);
|
|
|
-typedef wasi_nn_error (*SET_INPUT)(void *, graph_execution_context, uint32_t,
|
|
|
- tensor *);
|
|
|
-typedef wasi_nn_error (*COMPUTE)(void *, graph_execution_context);
|
|
|
-typedef wasi_nn_error (*GET_OUTPUT)(void *, graph_execution_context, uint32_t,
|
|
|
- tensor_data, uint32_t *);
|
|
|
-
|
|
|
-typedef struct {
|
|
|
- LOAD load;
|
|
|
- INIT_EXECUTION_CONTEXT init_execution_context;
|
|
|
- SET_INPUT set_input;
|
|
|
- COMPUTE compute;
|
|
|
- GET_OUTPUT get_output;
|
|
|
-} api_function;
|
|
|
-
|
|
|
/* Global variables */
|
|
|
-
|
|
|
-static api_function lookup[] = {
|
|
|
- { NULL, NULL, NULL, NULL, NULL },
|
|
|
- { NULL, NULL, NULL, NULL, NULL },
|
|
|
- { NULL, NULL, NULL, NULL, NULL },
|
|
|
- { NULL, NULL, NULL, NULL, NULL },
|
|
|
- { tensorflowlite_load, tensorflowlite_init_execution_context,
|
|
|
- tensorflowlite_set_input, tensorflowlite_compute,
|
|
|
- tensorflowlite_get_output }
|
|
|
-};
|
|
|
+static api_function lookup[backend_amount] = { 0 };
|
|
|
|
|
|
static HashMap *hashmap;
|
|
|
|
|
|
static void
|
|
|
wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx);
|
|
|
|
|
|
-/* Get wasi-nn context from module instance */
|
|
|
-
|
|
|
static uint32
|
|
|
hash_func(const void *key)
|
|
|
{
|
|
|
@@ -105,7 +73,16 @@ wasi_nn_initialize_context()
|
|
|
return NULL;
|
|
|
}
|
|
|
wasi_nn_ctx->is_model_loaded = false;
|
|
|
- tensorflowlite_initialize(&wasi_nn_ctx->tflite_ctx);
|
|
|
+ /* only one backend can be registered */
|
|
|
+ {
|
|
|
+ unsigned i;
|
|
|
+ for (i = 0; i < sizeof(lookup) / sizeof(lookup[0]); i++) {
|
|
|
+ if (lookup[i].init) {
|
|
|
+ lookup[i].init(&wasi_nn_ctx->backend_ctx);
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
return wasi_nn_ctx;
|
|
|
}
|
|
|
|
|
|
@@ -123,6 +100,7 @@ wasi_nn_initialize()
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
+/* Get wasi-nn context from module instance */
|
|
|
static WASINNContext *
|
|
|
wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
|
|
|
{
|
|
|
@@ -155,16 +133,30 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
|
|
|
NN_DBG_PRINTF("Freeing wasi-nn");
|
|
|
NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded);
|
|
|
NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->current_encoding);
|
|
|
- tensorflowlite_destroy(wasi_nn_ctx->tflite_ctx);
|
|
|
+ /* only one backend can be registered */
|
|
|
+ {
|
|
|
+ unsigned i;
|
|
|
+ for (i = 0; i < sizeof(lookup) / sizeof(lookup[0]); i++) {
|
|
|
+ if (lookup[i].deinit) {
|
|
|
+ lookup[i].deinit(wasi_nn_ctx->backend_ctx);
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
wasm_runtime_free(wasi_nn_ctx);
|
|
|
}
|
|
|
|
|
|
+static void
|
|
|
+wasi_nn_ctx_destroy_helper(void *instance, void *wasi_nn_ctx, void *user_data)
|
|
|
+{
|
|
|
+ wasi_nn_ctx_destroy((WASINNContext *)wasi_nn_ctx);
|
|
|
+}
|
|
|
+
|
|
|
void
|
|
|
-wasi_nn_destroy(wasm_module_inst_t instance)
|
|
|
+wasi_nn_destroy()
|
|
|
{
|
|
|
- WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
- bh_hash_map_remove(hashmap, (void *)instance, NULL, NULL);
|
|
|
- wasi_nn_ctx_destroy(wasi_nn_ctx);
|
|
|
+ bh_hash_map_traverse(hashmap, wasi_nn_ctx_destroy_helper, NULL);
|
|
|
+ bh_hash_map_destroy(hashmap);
|
|
|
}
|
|
|
|
|
|
/* Utils */
|
|
|
@@ -233,7 +225,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
|
|
}
|
|
|
|
|
|
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
- res = lookup[encoding].load(wasi_nn_ctx->tflite_ctx, &builder_native,
|
|
|
+ res = lookup[encoding].load(wasi_nn_ctx->backend_ctx, &builder_native,
|
|
|
encoding, target, g);
|
|
|
|
|
|
NN_DBG_PRINTF("wasi_nn_load finished with status %d [graph=%d]", res, *g);
|
|
|
@@ -270,7 +262,7 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
|
|
|
}
|
|
|
|
|
|
res = lookup[wasi_nn_ctx->current_encoding].init_execution_context(
|
|
|
- wasi_nn_ctx->tflite_ctx, g, ctx);
|
|
|
+ wasi_nn_ctx->backend_ctx, g, ctx);
|
|
|
|
|
|
NN_DBG_PRINTF(
|
|
|
"wasi_nn_init_execution_context finished with status %d [ctx=%d]", res,
|
|
|
@@ -300,7 +292,7 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|
|
return res;
|
|
|
|
|
|
res = lookup[wasi_nn_ctx->current_encoding].set_input(
|
|
|
- wasi_nn_ctx->tflite_ctx, ctx, index, &input_tensor_native);
|
|
|
+ wasi_nn_ctx->backend_ctx, ctx, index, &input_tensor_native);
|
|
|
|
|
|
// XXX: Free intermediate structure pointers
|
|
|
if (input_tensor_native.dimensions)
|
|
|
@@ -323,8 +315,8 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
|
|
|
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
|
|
return res;
|
|
|
|
|
|
- res = lookup[wasi_nn_ctx->current_encoding].compute(wasi_nn_ctx->tflite_ctx,
|
|
|
- ctx);
|
|
|
+ res = lookup[wasi_nn_ctx->current_encoding].compute(
|
|
|
+ wasi_nn_ctx->backend_ctx, ctx);
|
|
|
NN_DBG_PRINTF("wasi_nn_compute finished with status %d", res);
|
|
|
return res;
|
|
|
}
|
|
|
@@ -360,11 +352,13 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|
|
|
|
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
|
|
res = lookup[wasi_nn_ctx->current_encoding].get_output(
|
|
|
- wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, &output_tensor_len);
|
|
|
+ wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
|
|
|
+ &output_tensor_len);
|
|
|
*output_tensor_size = output_tensor_len;
|
|
|
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
|
|
|
res = lookup[wasi_nn_ctx->current_encoding].get_output(
|
|
|
- wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, output_tensor_size);
|
|
|
+ wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
|
|
|
+ output_tensor_size);
|
|
|
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
|
|
|
NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]",
|
|
|
res, *output_tensor_size);
|
|
|
@@ -397,17 +391,53 @@ static NativeSymbol native_symbols_wasi_nn[] = {
|
|
|
uint32_t
|
|
|
get_wasi_nn_export_apis(NativeSymbol **p_native_symbols)
|
|
|
{
|
|
|
- if (!wasi_nn_initialize())
|
|
|
- return 0;
|
|
|
*p_native_symbols = native_symbols_wasi_nn;
|
|
|
return sizeof(native_symbols_wasi_nn) / sizeof(NativeSymbol);
|
|
|
}
|
|
|
|
|
|
-#if defined(WASI_NN_SHARED)
|
|
|
-uint32_t
|
|
|
+__attribute__((used)) uint32_t
|
|
|
get_native_lib(char **p_module_name, NativeSymbol **p_native_symbols)
|
|
|
{
|
|
|
+ NN_DBG_PRINTF("--|> get_native_lib");
|
|
|
+
|
|
|
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
|
|
+ *p_module_name = "wasi_ephemeral_nn";
|
|
|
+#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
|
|
|
*p_module_name = "wasi_nn";
|
|
|
+#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
|
|
|
+
|
|
|
return get_wasi_nn_export_apis(p_native_symbols);
|
|
|
}
|
|
|
-#endif
|
|
|
+
|
|
|
+__attribute__((used)) int
|
|
|
+init_native_lib()
|
|
|
+{
|
|
|
+ NN_DBG_PRINTF("--|> init_native_lib");
|
|
|
+
|
|
|
+ if (!wasi_nn_initialize())
|
|
|
+ return 1;
|
|
|
+
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+__attribute__((used)) void
|
|
|
+deinit_native_lib()
|
|
|
+{
|
|
|
+ NN_DBG_PRINTF("--|> deinit_native_lib");
|
|
|
+
|
|
|
+ wasi_nn_destroy();
|
|
|
+}
|
|
|
+
|
|
|
+__attribute__((used)) bool
|
|
|
+wasi_nn_register_backend(graph_encoding backend_code, api_function apis)
|
|
|
+{
|
|
|
+ NN_DBG_PRINTF("--|> wasi_nn_register_backend");
|
|
|
+
|
|
|
+ if (backend_code >= sizeof(lookup) / sizeof(lookup[0])) {
|
|
|
+ NN_ERR_PRINTF("Invalid backend code");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ lookup[backend_code] = apis;
|
|
|
+ return true;
|
|
|
+}
|