|
|
@@ -9,16 +9,18 @@
|
|
|
#include <assert.h>
|
|
|
#include <errno.h>
|
|
|
#include <string.h>
|
|
|
+#include <stdint.h>
|
|
|
|
|
|
#include "wasi_nn.h"
|
|
|
+#include "wasi_nn_private.h"
|
|
|
#include "wasi_nn_app_native.h"
|
|
|
-#include "logger.h"
|
|
|
#include "wasi_nn_tensorflowlite.hpp"
|
|
|
+#include "logger.h"
|
|
|
|
|
|
#include "bh_platform.h"
|
|
|
#include "wasm_export.h"
|
|
|
-#include "wasm_runtime.h"
|
|
|
-#include "aot_runtime.h"
|
|
|
+
|
|
|
+#define HASHMAP_INITIAL_SIZE 20
|
|
|
|
|
|
/* Definition of 'wasi_nn.h' structs in WASM app format (using offset) */
|
|
|
|
|
|
@@ -51,6 +53,119 @@ static api_function lookup[] = {
|
|
|
tensorflowlite_get_output }
|
|
|
};
|
|
|
|
|
|
+static HashMap *hashmap;
|
|
|
+
|
|
|
+static void
|
|
|
+wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx);
|
|
|
+
|
|
|
+/* Get wasi-nn context from module instance */
|
|
|
+
|
|
|
+static uint32
|
|
|
+hash_func(const void *key)
|
|
|
+{
|
|
|
+ // fnv1a_hash
|
|
|
+ const uint32 FNV_PRIME = 16777619;
|
|
|
+ const uint32 FNV_OFFSET_BASIS = 2166136261U;
|
|
|
+
|
|
|
+ uint32 hash = FNV_OFFSET_BASIS;
|
|
|
+ const unsigned char *bytes = (const unsigned char *)key;
|
|
|
+
|
|
|
+ for (size_t i = 0; i < sizeof(uintptr_t); ++i) {
|
|
|
+ hash ^= bytes[i];
|
|
|
+ hash *= FNV_PRIME;
|
|
|
+ }
|
|
|
+
|
|
|
+ return hash;
|
|
|
+}
|
|
|
+
|
|
|
+static bool
|
|
|
+key_equal_func(void *key1, void *key2)
|
|
|
+{
|
|
|
+ return key1 == key2;
|
|
|
+}
|
|
|
+
|
|
|
+static void
|
|
|
+key_destroy_func(void *key1)
|
|
|
+{}
|
|
|
+
|
|
|
+static void
|
|
|
+value_destroy_func(void *value)
|
|
|
+{
|
|
|
+ wasi_nn_ctx_destroy((WASINNContext *)value);
|
|
|
+}
|
|
|
+
|
|
|
+static WASINNContext *
|
|
|
+wasi_nn_initialize_context()
|
|
|
+{
|
|
|
+ NN_DBG_PRINTF("Initializing wasi-nn context");
|
|
|
+ WASINNContext *wasi_nn_ctx =
|
|
|
+ (WASINNContext *)wasm_runtime_malloc(sizeof(WASINNContext));
|
|
|
+ if (wasi_nn_ctx == NULL) {
|
|
|
+ NN_ERR_PRINTF("Error when allocating memory for WASI-NN context");
|
|
|
+ return NULL;
|
|
|
+ }
|
|
|
+ wasi_nn_ctx->is_model_loaded = false;
|
|
|
+ tensorflowlite_initialize(&wasi_nn_ctx->tflite_ctx);
|
|
|
+ return wasi_nn_ctx;
|
|
|
+}
|
|
|
+
|
|
|
+static bool
|
|
|
+wasi_nn_initialize()
|
|
|
+{
|
|
|
+ NN_DBG_PRINTF("Initializing wasi-nn");
|
|
|
+ hashmap = bh_hash_map_create(HASHMAP_INITIAL_SIZE, true, hash_func,
|
|
|
+ key_equal_func, key_destroy_func,
|
|
|
+ value_destroy_func);
|
|
|
+ if (hashmap == NULL) {
|
|
|
+ NN_ERR_PRINTF("Error while initializing hashmap");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+}
|
|
|
+
|
|
|
+static WASINNContext *
|
|
|
+wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
|
|
|
+{
|
|
|
+ WASINNContext *wasi_nn_ctx =
|
|
|
+ (WASINNContext *)bh_hash_map_find(hashmap, (void *)instance);
|
|
|
+ if (wasi_nn_ctx == NULL) {
|
|
|
+ wasi_nn_ctx = wasi_nn_initialize_context();
|
|
|
+ if (wasi_nn_ctx == NULL)
|
|
|
+ return NULL;
|
|
|
+ bool ok =
|
|
|
+ bh_hash_map_insert(hashmap, (void *)instance, (void *)wasi_nn_ctx);
|
|
|
+ if (!ok) {
|
|
|
+ NN_ERR_PRINTF("Error while storing context");
|
|
|
+ wasi_nn_ctx_destroy(wasi_nn_ctx);
|
|
|
+ return NULL;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ NN_DBG_PRINTF("Returning ctx");
|
|
|
+ return wasi_nn_ctx;
|
|
|
+}
|
|
|
+
|
|
|
+static void
|
|
|
+wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
|
|
|
+{
|
|
|
+ if (wasi_nn_ctx == NULL) {
|
|
|
+ NN_ERR_PRINTF(
|
|
|
+ "Error when deallocating memory. WASI-NN context is NULL");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ NN_DBG_PRINTF("Freeing wasi-nn");
|
|
|
+ NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded);
|
|
|
+ NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->current_encoding);
|
|
|
+ tensorflowlite_destroy(wasi_nn_ctx->tflite_ctx);
|
|
|
+ wasm_runtime_free(wasi_nn_ctx);
|
|
|
+}
|
|
|
+
|
|
|
+void
|
|
|
+wasi_nn_destroy(wasm_module_inst_t instance)
|
|
|
+{
|
|
|
+ WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
|
+ wasi_nn_ctx_destroy(wasi_nn_ctx);
|
|
|
+}
|
|
|
+
|
|
|
/* Utils */
|
|
|
|
|
|
static bool
|
|
|
@@ -64,36 +179,13 @@ is_encoding_implemented(graph_encoding encoding)
|
|
|
static error
|
|
|
is_model_initialized(WASINNContext *wasi_nn_ctx)
|
|
|
{
|
|
|
- if (!wasi_nn_ctx->is_initialized) {
|
|
|
+ if (!wasi_nn_ctx->is_model_loaded) {
|
|
|
NN_ERR_PRINTF("Model not initialized.");
|
|
|
return runtime_error;
|
|
|
}
|
|
|
return success;
|
|
|
}
|
|
|
|
|
|
-WASINNContext *
|
|
|
-wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
|
|
|
-{
|
|
|
- WASINNContext *wasi_nn_ctx = NULL;
|
|
|
-#if WASM_ENABLE_INTERP != 0
|
|
|
- if (instance->module_type == Wasm_Module_Bytecode) {
|
|
|
- NN_DBG_PRINTF("Getting ctx from WASM");
|
|
|
- WASMModuleInstance *module_inst = (WASMModuleInstance *)instance;
|
|
|
- wasi_nn_ctx = ((WASMModuleInstanceExtra *)module_inst->e)->wasi_nn_ctx;
|
|
|
- }
|
|
|
-#endif
|
|
|
-#if WASM_ENABLE_AOT != 0
|
|
|
- if (instance->module_type == Wasm_Module_AoT) {
|
|
|
- NN_DBG_PRINTF("Getting ctx from AOT");
|
|
|
- AOTModuleInstance *module_inst = (AOTModuleInstance *)instance;
|
|
|
- wasi_nn_ctx = ((AOTModuleInstanceExtra *)module_inst->e)->wasi_nn_ctx;
|
|
|
- }
|
|
|
-#endif
|
|
|
- bh_assert(wasi_nn_ctx != NULL);
|
|
|
- NN_DBG_PRINTF("Returning ctx");
|
|
|
- return wasi_nn_ctx;
|
|
|
-}
|
|
|
-
|
|
|
/* WASI-NN implementation */
|
|
|
|
|
|
error
|
|
|
@@ -131,7 +223,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
|
|
NN_DBG_PRINTF("wasi_nn_load finished with status %d [graph=%d]", res, *g);
|
|
|
|
|
|
wasi_nn_ctx->current_encoding = encoding;
|
|
|
- wasi_nn_ctx->is_initialized = true;
|
|
|
+ wasi_nn_ctx->is_model_loaded = true;
|
|
|
|
|
|
fail:
|
|
|
// XXX: Free intermediate structure pointers
|
|
|
@@ -250,39 +342,6 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|
|
return res;
|
|
|
}
|
|
|
|
|
|
-/* Non-exposed public functions */
|
|
|
-
|
|
|
-WASINNContext *
|
|
|
-wasi_nn_initialize()
|
|
|
-{
|
|
|
- NN_DBG_PRINTF("Initializing wasi-nn");
|
|
|
- WASINNContext *wasi_nn_ctx =
|
|
|
- (WASINNContext *)wasm_runtime_malloc(sizeof(WASINNContext));
|
|
|
- if (wasi_nn_ctx == NULL) {
|
|
|
- NN_ERR_PRINTF("Error when allocating memory for WASI-NN context");
|
|
|
- return NULL;
|
|
|
- }
|
|
|
- wasi_nn_ctx->is_initialized = true;
|
|
|
- wasi_nn_ctx->current_encoding = 3;
|
|
|
- tensorflowlite_initialize(&wasi_nn_ctx->tflite_ctx);
|
|
|
- return wasi_nn_ctx;
|
|
|
-}
|
|
|
-
|
|
|
-void
|
|
|
-wasi_nn_destroy(WASINNContext *wasi_nn_ctx)
|
|
|
-{
|
|
|
- if (wasi_nn_ctx == NULL) {
|
|
|
- NN_ERR_PRINTF(
|
|
|
- "Error when deallocating memory. WASI-NN context is NULL");
|
|
|
- return;
|
|
|
- }
|
|
|
- NN_DBG_PRINTF("Freeing wasi-nn");
|
|
|
- NN_DBG_PRINTF("-> is_initialized: %d", wasi_nn_ctx->is_initialized);
|
|
|
- NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->current_encoding);
|
|
|
- tensorflowlite_destroy(wasi_nn_ctx->tflite_ctx);
|
|
|
- wasm_runtime_free(wasi_nn_ctx);
|
|
|
-}
|
|
|
-
|
|
|
/* Register WASI-NN in WAMR */
|
|
|
|
|
|
/* clang-format off */
|
|
|
@@ -299,8 +358,19 @@ static NativeSymbol native_symbols_wasi_nn[] = {
|
|
|
};
|
|
|
|
|
|
uint32_t
|
|
|
-get_wasi_nn_export_apis(NativeSymbol **p_libc_wasi_apis)
|
|
|
+get_wasi_nn_export_apis(NativeSymbol **p_native_symbols)
|
|
|
{
|
|
|
- *p_libc_wasi_apis = native_symbols_wasi_nn;
|
|
|
+ if (!wasi_nn_initialize())
|
|
|
+ return 0;
|
|
|
+ *p_native_symbols = native_symbols_wasi_nn;
|
|
|
return sizeof(native_symbols_wasi_nn) / sizeof(NativeSymbol);
|
|
|
}
|
|
|
+
|
|
|
+#if defined(WASI_NN_SHARED)
|
|
|
+uint32_t
|
|
|
+get_native_lib(char **p_module_name, NativeSymbol **p_native_symbols)
|
|
|
+{
|
|
|
+ *p_module_name = "wasi_nn";
|
|
|
+ return get_wasi_nn_export_apis(p_native_symbols);
|
|
|
+}
|
|
|
+#endif
|