Przeglądaj źródła

Refactor WASI-NN to simplify the support for multiple frameworks (#1834)

- Reorganize the library structure
- Use the latest version of `wasi-nn` wit (Oct 25, 2022):
    https://github.com/WebAssembly/wasi-nn/blob/0f77c48ec195748990ff67928a4b3eef5f16c2de/wasi-nn.wit.md
- Split logic that converts WASM structs to native structs in a separate file
- Simplify addition of new frameworks
tonibofarull 3 lat temu
rodzic
commit
9eed6686df

+ 7 - 3
build-scripts/runtime_lib.cmake

@@ -96,9 +96,13 @@ if (WAMR_BUILD_LIB_PTHREAD_SEMAPHORE EQUAL 1)
 endif ()
 
 if (WAMR_BUILD_WASI_NN EQUAL 1)
-    execute_process(COMMAND ${WAMR_ROOT_DIR}/core/deps/install_tensorflow.sh
-                    RESULT_VARIABLE TENSORFLOW_RESULT
-    )
+    if (NOT EXISTS "${WAMR_ROOT_DIR}/core/deps/tensorflow-src")
+        execute_process(COMMAND ${WAMR_ROOT_DIR}/core/deps/install_tensorflow.sh
+                        RESULT_VARIABLE TENSORFLOW_RESULT
+        )
+    else ()
+        message("Tensorflow is already downloaded.")
+    endif()
     set(TENSORFLOW_SOURCE_DIR "${WAMR_ROOT_DIR}/core/deps/tensorflow-src")
     include_directories (${CMAKE_CURRENT_BINARY_DIR}/flatbuffers/include)
     include_directories (${TENSORFLOW_SOURCE_DIR})

+ 20 - 0
core/iwasm/aot/aot_runtime.c

@@ -1083,6 +1083,17 @@ aot_instantiate(AOTModule *module, bool is_sub_inst, uint32 stack_size,
     }
 #endif
 
+#if WASM_ENABLE_WASI_NN != 0
+    if (!is_sub_inst) {
+        if (!(((AOTModuleInstanceExtra *)module_inst->e)->wasi_nn_ctx =
+                  wasi_nn_initialize())) {
+            set_error_buf(error_buf, error_buf_size,
+                          "wasi nn initialization failed");
+            goto fail;
+        }
+    }
+#endif
+
     /* Initialize the thread related data */
     if (stack_size == 0)
         stack_size = DEFAULT_WASM_STACK_SIZE;
@@ -1194,6 +1205,15 @@ aot_deinstantiate(AOTModuleInstance *module_inst, bool is_sub_inst)
         wasm_runtime_free(
             ((AOTModuleInstanceExtra *)module_inst->e)->c_api_func_imports);
 
+#if WASM_ENABLE_WASI_NN != 0
+    if (!is_sub_inst) {
+        WASINNContext *wasi_nn_ctx =
+            ((AOTModuleInstanceExtra *)module_inst->e)->wasi_nn_ctx;
+        if (wasi_nn_ctx)
+            wasi_nn_destroy(wasi_nn_ctx);
+    }
+#endif
+
     wasm_runtime_free(module_inst);
 }
 

+ 7 - 0
core/iwasm/aot/aot_runtime.h

@@ -11,6 +11,10 @@
 #include "../interpreter/wasm_runtime.h"
 #include "../compilation/aot.h"
 
+#if WASM_ENABLE_WASI_NN != 0
+#include "../libraries/wasi-nn/src/wasi_nn_private.h"
+#endif
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -75,6 +79,9 @@ typedef struct AOTFunctionInstance {
 
 typedef struct AOTModuleInstanceExtra {
     CApiFuncImport *c_api_func_imports;
+#if WASM_ENABLE_WASI_NN != 0
+    WASINNContext *wasi_nn_ctx;
+#endif
 } AOTModuleInstanceExtra;
 
 #if defined(OS_ENABLE_HW_BOUND_CHECK) && defined(BH_PLATFORM_WINDOWS)

+ 18 - 0
core/iwasm/interpreter/wasm_runtime.c

@@ -1803,6 +1803,16 @@ wasm_instantiate(WASMModule *module, bool is_sub_inst, uint32 stack_size,
     }
 #endif
 
+#if WASM_ENABLE_WASI_NN != 0
+    if (!is_sub_inst) {
+        if (!(module_inst->e->wasi_nn_ctx = wasi_nn_initialize())) {
+            set_error_buf(error_buf, error_buf_size,
+                          "wasi nn initialization failed");
+            goto fail;
+        }
+    }
+#endif
+
 #if WASM_ENABLE_DEBUG_INTERP != 0                         \
     || (WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \
         && WASM_ENABLE_LAZY_JIT != 0)
@@ -1984,6 +1994,14 @@ wasm_deinstantiate(WASMModuleInstance *module_inst, bool is_sub_inst)
     if (module_inst->e->c_api_func_imports)
         wasm_runtime_free(module_inst->e->c_api_func_imports);
 
+#if WASM_ENABLE_WASI_NN != 0
+    if (!is_sub_inst) {
+        WASINNContext *wasi_nn_ctx = module_inst->e->wasi_nn_ctx;
+        if (wasi_nn_ctx)
+            wasi_nn_destroy(wasi_nn_ctx);
+    }
+#endif
+
     wasm_runtime_free(module_inst);
 }
 

+ 8 - 0
core/iwasm/interpreter/wasm_runtime.h

@@ -11,6 +11,10 @@
 #include "../common/wasm_runtime_common.h"
 #include "../common/wasm_exec_env.h"
 
+#if WASM_ENABLE_WASI_NN != 0
+#include "../libraries/wasi-nn/src/wasi_nn_private.h"
+#endif
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -242,6 +246,10 @@ typedef struct WASMModuleInstanceExtra {
         && WASM_ENABLE_LAZY_JIT != 0)
     WASMModuleInstance *next;
 #endif
+
+#if WASM_ENABLE_WASI_NN != 0
+    WASINNContext *wasi_nn_ctx;
+#endif
 } WASMModuleInstanceExtra;
 
 struct AOTFuncPerfProfInfo;

+ 0 - 1
core/iwasm/libraries/wasi-nn/.dockerignore

@@ -1 +0,0 @@
-**/Dockerfile

+ 7 - 3
core/iwasm/libraries/wasi-nn/README.md

@@ -37,7 +37,11 @@ Tests: passed!
 
 ## What is missing
 
-* Only 1 model at a time is supported.
+Supported:
+
+* Only 1 WASM app at a time.
+* Only 1 model at a time.
     * `graph` and `graph-execution-context` are ignored.
-* Only `tensorflow` (lite) is supported.
-* Only `cpu` is supported.
+* Graph encoding: `tensorflowlite`.
+* Execution target: `cpu`.
+* Tensor type: `fp32`.

+ 0 - 55
core/iwasm/libraries/wasi-nn/logger.h

@@ -1,55 +0,0 @@
-/*
- * Copyright (C) 2019 Intel Corporation.  All rights reserved.
- * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- */
-
-#ifndef WASI_NN_LOGGER_H
-#define WASI_NN_LOGGER_H
-
-#include <stdio.h>
-#include <string.h>
-
-#define __FILENAME__ \
-    (strrchr(__FILE__, '/') ? strrchr(__FILE__, '/') + 1 : __FILE__)
-
-/* Disable a level by removing the define */
-#define ENABLE_ERR_LOG
-#define ENABLE_WARN_LOG
-#define ENABLE_DBG_LOG
-#define ENABLE_INFO_LOG
-
-// Definition of the levels
-#ifdef ENABLE_ERR_LOG
-#define NN_ERR_PRINTF(fmt, ...)                                    \
-    printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
-    printf("\n");                                                  \
-    fflush(stdout)
-#else
-#define NN_ERR_PRINTF(fmt, ...)
-#endif
-#ifdef ENABLE_WARN_LOG
-#define NN_WARN_PRINTF(fmt, ...)                                   \
-    printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
-    printf("\n");                                                  \
-    fflush(stdout)
-#else
-#define NN_WARN_PRINTF(fmt, ...)
-#endif
-#ifdef ENABLE_DBG_LOG
-#define NN_DBG_PRINTF(fmt, ...)                                    \
-    printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
-    printf("\n");                                                  \
-    fflush(stdout)
-#else
-#define NN_DBG_PRINTF(fmt, ...)
-#endif
-#ifdef ENABLE_INFO_LOG
-#define NN_INFO_PRINTF(fmt, ...)                                   \
-    printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
-    printf("\n");                                                  \
-    fflush(stdout)
-#else
-#define NN_INFO_PRINTF(fmt, ...)
-#endif
-
-#endif

+ 63 - 0
core/iwasm/libraries/wasi-nn/src/utils/logger.h

@@ -0,0 +1,63 @@
+/*
+ * Copyright (C) 2019 Intel Corporation.  All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+
+#ifndef WASI_NN_LOGGER_H
+#define WASI_NN_LOGGER_H
+
+#include <stdio.h>
+#include <string.h>
+
+#define __FILENAME__ \
+    (strrchr(__FILE__, '/') ? strrchr(__FILE__, '/') + 1 : __FILE__)
+
+/* Disable a level by removing the define */
+#define ENABLE_ERR_LOG
+#define ENABLE_WARN_LOG
+#define ENABLE_DBG_LOG
+#define ENABLE_INFO_LOG
+
+// Definition of the levels
+#ifdef ENABLE_ERR_LOG
+#define NN_ERR_PRINTF(fmt, ...)                                        \
+    do {                                                               \
+        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
+        printf("\n");                                                  \
+        fflush(stdout);                                                \
+    } while (0)
+#else
+#define NN_ERR_PRINTF(fmt, ...)
+#endif
+#ifdef ENABLE_WARN_LOG
+#define NN_WARN_PRINTF(fmt, ...)                                       \
+    do {                                                               \
+        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
+        printf("\n");                                                  \
+        fflush(stdout);                                                \
+    } while (0)
+#else
+#define NN_WARN_PRINTF(fmt, ...)
+#endif
+#ifdef ENABLE_DBG_LOG
+#define NN_DBG_PRINTF(fmt, ...)                                        \
+    do {                                                               \
+        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
+        printf("\n");                                                  \
+        fflush(stdout);                                                \
+    } while (0)
+#else
+#define NN_DBG_PRINTF(fmt, ...)
+#endif
+#ifdef ENABLE_INFO_LOG
+#define NN_INFO_PRINTF(fmt, ...)                                       \
+    do {                                                               \
+        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
+        printf("\n");                                                  \
+        fflush(stdout);                                                \
+    } while (0)
+#else
+#define NN_INFO_PRINTF(fmt, ...)
+#endif
+
+#endif

+ 163 - 0
core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.c

@@ -0,0 +1,163 @@
+/*
+ * Copyright (C) 2019 Intel Corporation.  All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+
+#include "wasi_nn_app_native.h"
+
+static error
+graph_builder_app_native(wasm_module_inst_t instance,
+                         graph_builder_wasm *builder_wasm,
+                         graph_builder *builder)
+{
+    if (!wasm_runtime_validate_app_addr(instance, builder_wasm->buf_offset,
+                                        builder_wasm->size * sizeof(uint8_t))) {
+        NN_ERR_PRINTF("builder_wasm->buf_offset is invalid");
+        return invalid_argument;
+    }
+
+    builder->buf = (uint8_t *)wasm_runtime_addr_app_to_native(
+        instance, builder_wasm->buf_offset);
+    builder->size = builder_wasm->size;
+    return success;
+}
+
+error
+graph_builder_array_app_native(wasm_module_inst_t instance,
+                               graph_builder_array_wasm *builder_array_wasm,
+                               graph_builder_array *builder_array)
+{
+    if (!wasm_runtime_validate_native_addr(instance, builder_array_wasm,
+                                           sizeof(graph_builder_array_wasm))) {
+        NN_ERR_PRINTF("builder_array_wasm is invalid");
+        return invalid_argument;
+    }
+
+    NN_DBG_PRINTF("Graph builder array contains %d elements",
+                  builder_array_wasm->size);
+
+    if (!wasm_runtime_validate_app_addr(
+            instance, builder_array_wasm->buf_offset,
+            builder_array_wasm->size * sizeof(graph_builder_wasm))) {
+        NN_ERR_PRINTF("builder_array_wasm->buf_offset is invalid");
+        return invalid_argument;
+    }
+
+    graph_builder_wasm *builder_wasm =
+        (graph_builder_wasm *)wasm_runtime_addr_app_to_native(
+            instance, builder_array_wasm->buf_offset);
+
+    graph_builder *builder = (graph_builder *)wasm_runtime_malloc(
+        builder_array_wasm->size * sizeof(graph_builder));
+    if (builder == NULL)
+        return missing_memory;
+
+    for (uint32_t i = 0; i < builder_array_wasm->size; ++i) {
+        error res;
+        if (success
+            != (res = graph_builder_app_native(instance, &builder_wasm[i],
+                                               &builder[i]))) {
+            wasm_runtime_free(builder);
+            return res;
+        }
+
+        NN_DBG_PRINTF("Graph builder %d contains %d elements", i,
+                      builder->size);
+    }
+
+    builder_array->buf = builder;
+    builder_array->size = builder_array_wasm->size;
+    return success;
+}
+
+static error
+tensor_data_app_native(wasm_module_inst_t instance, uint32_t total_elements,
+                       tensor_wasm *input_tensor_wasm, tensor_data *data)
+{
+    if (!wasm_runtime_validate_app_addr(
+            instance, input_tensor_wasm->data_offset, total_elements)) {
+        NN_ERR_PRINTF("input_tensor_wasm->data_offset is invalid");
+        return invalid_argument;
+    }
+    *data = (tensor_data)wasm_runtime_addr_app_to_native(
+        instance, input_tensor_wasm->data_offset);
+    return success;
+}
+
+static error
+tensor_dimensions_app_native(wasm_module_inst_t instance,
+                             tensor_wasm *input_tensor_wasm,
+                             tensor_dimensions **dimensions)
+{
+    if (!wasm_runtime_validate_app_addr(instance,
+                                        input_tensor_wasm->dimensions_offset,
+                                        sizeof(tensor_dimensions_wasm))) {
+        NN_ERR_PRINTF("input_tensor_wasm->dimensions_offset is invalid");
+        return invalid_argument;
+    }
+
+    tensor_dimensions_wasm *dimensions_wasm =
+        (tensor_dimensions_wasm *)wasm_runtime_addr_app_to_native(
+            instance, input_tensor_wasm->dimensions_offset);
+
+    if (!wasm_runtime_validate_app_addr(instance, dimensions_wasm->buf_offset,
+                                        sizeof(tensor_dimensions))) {
+        NN_ERR_PRINTF("dimensions_wasm->buf_offset is invalid");
+        return invalid_argument;
+    }
+
+    *dimensions =
+        (tensor_dimensions *)wasm_runtime_malloc(sizeof(tensor_dimensions));
+    if (dimensions == NULL)
+        return missing_memory;
+
+    (*dimensions)->size = dimensions_wasm->size;
+    (*dimensions)->buf = (uint32_t *)wasm_runtime_addr_app_to_native(
+        instance, dimensions_wasm->buf_offset);
+
+    NN_DBG_PRINTF("Number of dimensions: %d", (*dimensions)->size);
+    return success;
+}
+
+error
+tensor_app_native(wasm_module_inst_t instance, tensor_wasm *input_tensor_wasm,
+                  tensor *input_tensor)
+{
+    NN_DBG_PRINTF("Converting tensor_wasm to tensor");
+    if (!wasm_runtime_validate_native_addr(instance, input_tensor_wasm,
+                                           sizeof(tensor_wasm))) {
+        NN_ERR_PRINTF("input_tensor_wasm is invalid");
+        return invalid_argument;
+    }
+
+    error res;
+
+    tensor_dimensions *dimensions = NULL;
+    if (success
+        != (res = tensor_dimensions_app_native(instance, input_tensor_wasm,
+                                               &dimensions))) {
+        NN_ERR_PRINTF("error when parsing dimensions");
+        return res;
+    }
+
+    uint32_t total_elements = 1;
+    for (uint32_t i = 0; i < dimensions->size; ++i) {
+        total_elements *= dimensions->buf[i];
+        NN_DBG_PRINTF("Dimension %d: %d", i, dimensions->buf[i]);
+    }
+    NN_DBG_PRINTF("Tensor type: %d", input_tensor_wasm->type);
+    NN_DBG_PRINTF("Total number of elements: %d", total_elements);
+
+    tensor_data data = NULL;
+    if (success
+        != (res = tensor_data_app_native(instance, total_elements,
+                                         input_tensor_wasm, &data))) {
+        wasm_runtime_free(dimensions);
+        return res;
+    }
+
+    input_tensor->type = input_tensor_wasm->type;
+    input_tensor->dimensions = dimensions;
+    input_tensor->data = data;
+    return success;
+}

+ 51 - 0
core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.h

@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2019 Intel Corporation.  All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+
+#ifndef WASI_NN_APP_NATIVE
+#define WASI_NN_APP_NATIVE
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+#include <errno.h>
+#include <string.h>
+
+#include "wasi_nn.h"
+#include "logger.h"
+
+#include "bh_platform.h"
+#include "wasm_export.h"
+
+typedef struct {
+    uint32_t buf_offset;
+    uint32_t size;
+} graph_builder_wasm;
+
+typedef struct {
+    uint32_t buf_offset;
+    uint32_t size;
+} graph_builder_array_wasm;
+
+typedef struct {
+    uint32_t buf_offset;
+    uint32_t size;
+} tensor_dimensions_wasm;
+
+typedef struct {
+    uint32_t dimensions_offset;
+    tensor_type type;
+    uint32_t data_offset;
+} tensor_wasm;
+
+error
+graph_builder_array_app_native(wasm_module_inst_t instance,
+                               graph_builder_array_wasm *builder,
+                               graph_builder_array *builder_native);
+
+error
+tensor_app_native(wasm_module_inst_t instance, tensor_wasm *input_tensor,
+                  tensor *input_tensor_native);
+
+#endif

+ 302 - 0
core/iwasm/libraries/wasi-nn/src/wasi_nn.c

@@ -0,0 +1,302 @@
+/*
+ * Copyright (C) 2019 Intel Corporation.  All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include <assert.h>
+#include <errno.h>
+#include <string.h>
+
+#include "wasi_nn.h"
+#include "wasi_nn_app_native.h"
+#include "logger.h"
+#include "wasi_nn_tensorflowlite.hpp"
+
+#include "bh_platform.h"
+#include "wasm_export.h"
+#include "wasm_runtime.h"
+#include "aot_runtime.h"
+
+/* Definition of 'wasi_nn.h' structs in WASM app format (using offset) */
+
+typedef error (*LOAD)(graph_builder_array *, graph_encoding, execution_target,
+                      graph *);
+typedef error (*INIT_EXECUTION_CONTEXT)(graph, graph_execution_context *);
+typedef error (*SET_INPUT)(graph_execution_context, uint32_t, tensor *);
+typedef error (*COMPUTE)(graph_execution_context);
+typedef error (*GET_OUTPUT)(graph_execution_context, uint32_t, tensor_data,
+                            uint32_t *);
+
+typedef struct {
+    LOAD load;
+    INIT_EXECUTION_CONTEXT init_execution_context;
+    SET_INPUT set_input;
+    COMPUTE compute;
+    GET_OUTPUT get_output;
+} api_function;
+
+/* Global variables */
+
+static api_function lookup[] = {
+    { NULL, NULL, NULL, NULL, NULL },
+    { NULL, NULL, NULL, NULL, NULL },
+    { NULL, NULL, NULL, NULL, NULL },
+    { NULL, NULL, NULL, NULL, NULL },
+    { tensorflowlite_load, tensorflowlite_init_execution_context,
+      tensorflowlite_set_input, tensorflowlite_compute,
+      tensorflowlite_get_output }
+};
+
+/* Utils */
+
+static bool
+is_encoding_implemented(graph_encoding encoding)
+{
+    return lookup[encoding].load && lookup[encoding].init_execution_context
+           && lookup[encoding].set_input && lookup[encoding].compute
+           && lookup[encoding].get_output;
+}
+
+static error
+is_model_initialized(WASINNContext *wasi_nn_ctx)
+{
+    if (!wasi_nn_ctx->is_initialized) {
+        NN_ERR_PRINTF("Model not initialized.");
+        return runtime_error;
+    }
+    return success;
+}
+
+WASINNContext *
+wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
+{
+    WASINNContext *wasi_nn_ctx = NULL;
+#if WASM_ENABLE_INTERP != 0
+    if (instance->module_type == Wasm_Module_Bytecode) {
+        NN_DBG_PRINTF("Getting ctx from WASM");
+        WASMModuleInstance *module_inst = (WASMModuleInstance *)instance;
+        wasi_nn_ctx = ((WASMModuleInstanceExtra *)module_inst->e)->wasi_nn_ctx;
+    }
+#endif
+#if WASM_ENABLE_AOT != 0
+    if (instance->module_type == Wasm_Module_AoT) {
+        NN_DBG_PRINTF("Getting ctx from AOT");
+        AOTModuleInstance *module_inst = (AOTModuleInstance *)instance;
+        wasi_nn_ctx = ((AOTModuleInstanceExtra *)module_inst->e)->wasi_nn_ctx;
+    }
+#endif
+    bh_assert(wasi_nn_ctx != NULL);
+    NN_DBG_PRINTF("Returning ctx");
+    return wasi_nn_ctx;
+}
+
+/* WASI-NN implementation */
+
+error
+wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
+             graph_encoding encoding, execution_target target, graph *g)
+{
+    NN_DBG_PRINTF("Running wasi_nn_load [encoding=%d, target=%d]...", encoding,
+                  target);
+
+    if (!is_encoding_implemented(encoding)) {
+        NN_ERR_PRINTF("Encoding not supported.");
+        return invalid_encoding;
+    }
+
+    wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
+    bh_assert(instance);
+
+    error res;
+    graph_builder_array builder_native = { 0 };
+    if (success
+        != (res = graph_builder_array_app_native(instance, builder,
+                                                 &builder_native)))
+        return res;
+
+    if (!wasm_runtime_validate_native_addr(instance, g, sizeof(graph))) {
+        NN_ERR_PRINTF("graph is invalid");
+        res = invalid_argument;
+        goto fail;
+    }
+
+    res = lookup[encoding].load(&builder_native, encoding, target, g);
+
+    NN_DBG_PRINTF("wasi_nn_load finished with status %d [graph=%d]", res, *g);
+
+    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
+
+    wasi_nn_ctx->current_encoding = encoding;
+    wasi_nn_ctx->is_initialized = true;
+
+fail:
+    // XXX: Free intermediate structure pointers
+    if (builder_native.buf)
+        wasm_runtime_free(builder_native.buf);
+
+    return res;
+}
+
+error
+wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
+                               graph_execution_context *ctx)
+{
+    NN_DBG_PRINTF("Running wasi_nn_init_execution_context [graph=%d]...", g);
+
+    wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
+    bh_assert(instance);
+    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
+
+    error res;
+    if (success != (res = is_model_initialized(wasi_nn_ctx)))
+        return res;
+
+    if (!wasm_runtime_validate_native_addr(instance, ctx,
+                                           sizeof(graph_execution_context))) {
+        NN_ERR_PRINTF("ctx is invalid");
+        return invalid_argument;
+    }
+
+    res = lookup[wasi_nn_ctx->current_encoding].init_execution_context(g, ctx);
+    *ctx = g;
+    NN_DBG_PRINTF(
+        "wasi_nn_init_execution_context finished with status %d [ctx=%d]", res,
+        *ctx);
+    return res;
+}
+
+error
+wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
+                  uint32_t index, tensor_wasm *input_tensor)
+{
+    NN_DBG_PRINTF("Running wasi_nn_set_input [ctx=%d, index=%d]...", ctx,
+                  index);
+
+    wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
+    bh_assert(instance);
+    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
+
+    error res;
+    if (success != (res = is_model_initialized(wasi_nn_ctx)))
+        return res;
+
+    tensor input_tensor_native = { 0 };
+    if (success
+        != (res = tensor_app_native(instance, input_tensor,
+                                    &input_tensor_native)))
+        return res;
+
+    res = lookup[wasi_nn_ctx->current_encoding].set_input(ctx, index,
+                                                          &input_tensor_native);
+
+    // XXX: Free intermediate structure pointers
+    if (input_tensor_native.dimensions)
+        wasm_runtime_free(input_tensor_native.dimensions);
+
+    NN_DBG_PRINTF("wasi_nn_set_input finished with status %d", res);
+    return res;
+}
+
+error
+wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
+{
+    NN_DBG_PRINTF("Running wasi_nn_compute [ctx=%d]...", ctx);
+
+    wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
+    bh_assert(instance);
+    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
+
+    error res;
+    if (success != (res = is_model_initialized(wasi_nn_ctx)))
+        return res;
+
+    res = lookup[wasi_nn_ctx->current_encoding].compute(ctx);
+    NN_DBG_PRINTF("wasi_nn_compute finished with status %d", res);
+    return res;
+}
+
+error
+wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
+                   uint32_t index, tensor_data output_tensor,
+                   uint32_t *output_tensor_size)
+{
+    NN_DBG_PRINTF("Running wasi_nn_get_output [ctx=%d, index=%d]...", ctx,
+                  index);
+
+    wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
+    bh_assert(instance);
+    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
+
+    error res;
+    if (success != (res = is_model_initialized(wasi_nn_ctx)))
+        return res;
+
+    if (!wasm_runtime_validate_native_addr(instance, output_tensor_size,
+                                           sizeof(uint32_t))) {
+        NN_ERR_PRINTF("output_tensor_size is invalid");
+        return invalid_argument;
+    }
+
+    res = lookup[wasi_nn_ctx->current_encoding].get_output(
+        ctx, index, output_tensor, output_tensor_size);
+    NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]",
+                  res, *output_tensor_size);
+    return res;
+}
+
+/* Non-exposed public functions */
+
+WASINNContext *
+wasi_nn_initialize()
+{
+    NN_DBG_PRINTF("Initializing wasi-nn");
+    WASINNContext *wasi_nn_ctx =
+        (WASINNContext *)wasm_runtime_malloc(sizeof(WASINNContext));
+    if (wasi_nn_ctx == NULL) {
+        NN_ERR_PRINTF("Error when allocating memory for WASI-NN context");
+        return NULL;
+    }
+    wasi_nn_ctx->is_initialized = true;
+    wasi_nn_ctx->current_encoding = 3;
+    return wasi_nn_ctx;
+}
+
+void
+wasi_nn_destroy(WASINNContext *wasi_nn_ctx)
+{
+    if (wasi_nn_ctx == NULL) {
+        NN_ERR_PRINTF(
+            "Error when deallocating memory. WASI-NN context is NULL");
+        return;
+    }
+    NN_DBG_PRINTF("Freeing wasi-nn");
+    NN_DBG_PRINTF("-> is_initialized: %d", wasi_nn_ctx->is_initialized);
+    NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->current_encoding);
+    tensorflowlite_destroy();
+    wasm_runtime_free(wasi_nn_ctx);
+}
+
+/* Register WASI-NN in WAMR */
+
+/* clang-format off */
+#define REG_NATIVE_FUNC(func_name, signature) \
+    { #func_name, wasi_nn_##func_name, signature, NULL }
+/* clang-format on */
+
+static NativeSymbol native_symbols_wasi_nn[] = {
+    REG_NATIVE_FUNC(load, "(*ii*)i"),
+    REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
+    REG_NATIVE_FUNC(set_input, "(ii*)i"),
+    REG_NATIVE_FUNC(compute, "(i)i"),
+    REG_NATIVE_FUNC(get_output, "(ii**)i"),
+};
+
+uint32_t
+get_wasi_nn_export_apis(NativeSymbol **p_libc_wasi_apis)
+{
+    *p_libc_wasi_apis = native_symbols_wasi_nn;
+    return sizeof(native_symbols_wasi_nn) / sizeof(NativeSymbol);
+}

+ 30 - 0
core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h

@@ -0,0 +1,30 @@
+/*
+ * Copyright (C) 2019 Intel Corporation.  All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+
+#ifndef WASI_NN_PRIVATE_H
+#define WASI_NN_PRIVATE_H
+
+#include "wasi_nn_types.h"
+
+typedef struct {
+    bool is_initialized;
+    graph_encoding current_encoding;
+} WASINNContext;
+
+/**
+ * @brief Initialize wasi-nn
+ *
+ */
+WASINNContext *
+wasi_nn_initialize();
+/**
+ * @brief Destroy wasi-nn on app exists
+ *
+ */
+
+void
+wasi_nn_destroy(WASINNContext *wasi_nn_ctx);
+
+#endif

+ 40 - 18
core/iwasm/libraries/wasi-nn/wasi_nn_tensorflow.cpp → core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp

@@ -3,8 +3,10 @@
  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  */
 
-#include "wasi_nn_tensorflow.hpp"
-#include "wasi_nn_common.h"
+#include "wasi_nn.h"
+#include "wasi_nn_tensorflowlite.hpp"
+#include "logger.h"
+
 #include "bh_common.h"
 #include "bh_platform.h"
 #include "platform_common.h"
@@ -25,21 +27,21 @@ static char *model_pointer = NULL;
 /* WASI-NN (tensorflow) implementation */
 
 error
-tensorflow_load(graph_builder_array builder, graph_encoding encoding,
-                execution_target target, graph *graph)
+tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
+                    execution_target target, graph *g)
 {
     if (model_pointer != NULL) {
         wasm_runtime_free(model_pointer);
         model_pointer = NULL;
     }
 
-    if (builder.size != 1) {
+    if (builder->size != 1) {
         NN_ERR_PRINTF("Unexpected builder format.");
         return invalid_argument;
     }
 
-    if (encoding != tensorflow) {
-        NN_ERR_PRINTF("Encoding is not tensorflow.");
+    if (encoding != tensorflowlite) {
+        NN_ERR_PRINTF("Encoding is not tensorflowlite.");
         return invalid_argument;
     }
 
@@ -48,7 +50,7 @@ tensorflow_load(graph_builder_array builder, graph_encoding encoding,
         return invalid_argument;
     }
 
-    uint32_t size = builder.buf[0].size;
+    uint32_t size = builder->buf[0].size;
 
     model_pointer = (char *)wasm_runtime_malloc(size);
     if (model_pointer == NULL) {
@@ -56,7 +58,7 @@ tensorflow_load(graph_builder_array builder, graph_encoding encoding,
         return missing_memory;
     }
 
-    bh_memcpy_s(model_pointer, size, builder.buf[0].buf, size);
+    bh_memcpy_s(model_pointer, size, builder->buf[0].buf, size);
 
     model = tflite::FlatBufferModel::BuildFromBuffer(model_pointer, size, NULL);
     if (model == NULL) {
@@ -81,7 +83,7 @@ tensorflow_load(graph_builder_array builder, graph_encoding encoding,
 }
 
 error
-tensorflow_init_execution_context(graph graph)
+tensorflowlite_init_execution_context(graph g, graph_execution_context *ctx)
 {
     if (interpreter == NULL) {
         NN_ERR_PRINTF("Non-initialized interpreter.");
@@ -92,8 +94,8 @@ tensorflow_init_execution_context(graph graph)
 }
 
 error
-tensorflow_set_input(graph_execution_context ctx, uint32_t index,
-                     tensor *input_tensor)
+tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
+                         tensor *input_tensor)
 {
     if (interpreter == NULL) {
         NN_ERR_PRINTF("Non-initialized interpreter.");
@@ -113,11 +115,11 @@ tensorflow_set_input(graph_execution_context ctx, uint32_t index,
     }
 
     uint32_t model_tensor_size = 1;
-    for (int i = 0; i < (int)tensor->dims->size; ++i)
+    for (int i = 0; i < tensor->dims->size; ++i)
         model_tensor_size *= (uint32_t)tensor->dims->data[i];
 
     uint32_t input_tensor_size = 1;
-    for (int i = 0; i < input_tensor->dimensions->size; i++)
+    for (uint32_t i = 0; i < input_tensor->dimensions->size; i++)
         input_tensor_size *= (uint32_t)input_tensor->dimensions->buf[i];
 
     if (model_tensor_size != input_tensor_size) {
@@ -136,7 +138,7 @@ tensorflow_set_input(graph_execution_context ctx, uint32_t index,
 }
 
 error
-tensorflow_compute(graph_execution_context ctx)
+tensorflowlite_compute(graph_execution_context ctx)
 {
     if (interpreter == NULL) {
         NN_ERR_PRINTF("Non-initialized interpreter.");
@@ -147,8 +149,9 @@ tensorflow_compute(graph_execution_context ctx)
 }
 
 error
-tensorflow_get_output(graph_execution_context context, uint32_t index,
-                      tensor_data output_tensor, uint32_t *output_tensor_size)
+tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
+                          tensor_data output_tensor,
+                          uint32_t *output_tensor_size)
 {
     if (interpreter == NULL) {
         NN_ERR_PRINTF("Non-initialized interpreter.");
@@ -178,7 +181,7 @@ tensorflow_get_output(graph_execution_context context, uint32_t index,
     }
 
     float *tensor_f = interpreter->typed_output_tensor<float>(index);
-    for (int i = 0; i < model_tensor_size; ++i)
+    for (uint32_t i = 0; i < model_tensor_size; ++i)
         NN_DBG_PRINTF("output: %f", tensor_f[i]);
 
     *output_tensor_size = model_tensor_size;
@@ -186,3 +189,22 @@ tensorflow_get_output(graph_execution_context context, uint32_t index,
                 model_tensor_size * sizeof(float));
     return success;
 }
+
+void
+tensorflowlite_destroy()
+{
+    /*
+        TensorFlow Lite memory is man
+
+        Related issues:
+        * https://github.com/tensorflow/tensorflow/issues/15880
+    */
+    NN_DBG_PRINTF("Freeing memory.");
+    model.reset(nullptr);
+    model = NULL;
+    interpreter.reset(nullptr);
+    interpreter = NULL;
+    wasm_runtime_free(model_pointer);
+    model_pointer = NULL;
+    NN_DBG_PRINTF("Memory free'd.");
+}

+ 41 - 0
core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.hpp

@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2019 Intel Corporation.  All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+
+#ifndef WASI_NN_TENSORFLOWLITE_HPP
+#define WASI_NN_TENSORFLOWLITE_HPP
+
+#include "wasi_nn.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+error
+tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
+                    execution_target target, graph *g);
+
+error
+tensorflowlite_init_execution_context(graph g, graph_execution_context *ctx);
+
+error
+tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
+                         tensor *input_tensor);
+
+error
+tensorflowlite_compute(graph_execution_context ctx);
+
+error
+tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
+                          tensor_data output_tensor,
+                          uint32_t *output_tensor_size);
+
+void
+tensorflowlite_destroy();
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif

+ 1 - 0
core/iwasm/libraries/wasi-nn/test/.dockerignore

@@ -0,0 +1 @@
+Dockerfile

+ 11 - 5
core/iwasm/libraries/wasi-nn/test/Dockerfile

@@ -8,18 +8,24 @@ ENV DEBIAN_FRONTEND=noninteractive
 RUN apt-get update && apt-get install -y \
     cmake build-essential git wget python3.10 python3-pip
 
-RUN wget -q https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-14/wasi-sdk-14.0-linux.tar.gz && \
-    tar xf wasi-sdk-*-linux.tar.gz -C /opt && rm -f wasi-sdk-*-linux.tar.gz && \
-    mv /opt/wasi-sdk-14.0 /opt/wasi-sdk
+ARG WASI_SDK_VER=16
+RUN wget -c --progress=dot:giga https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-${WASI_SDK_VER}/wasi-sdk-${WASI_SDK_VER}.0-linux.tar.gz -P /opt \
+  && tar xf /opt/wasi-sdk-${WASI_SDK_VER}.0-linux.tar.gz -C /opt \
+  && ln -fs /opt/wasi-sdk-${WASI_SDK_VER}.0 /opt/wasi-sdk \
+  && rm /opt/wasi-sdk-${WASI_SDK_VER}.0-linux.tar.gz
 
 WORKDIR /home/wamr
 
+COPY core/deps/install_tensorflow.sh core/deps/install_tensorflow.sh
+RUN ./core/deps/install_tensorflow.sh
+
+COPY core/iwasm/libraries/wasi-nn/test/requirements.txt .
+RUN pip3 install -r requirements.txt
+
 COPY core core
 COPY build-scripts build-scripts
 COPY product-mini product-mini
 
-RUN pip3 install -r core/iwasm/libraries/wasi-nn/test/requirements.txt
-
 WORKDIR /home/wamr/core/iwasm/libraries/wasi-nn/test/build
 
 RUN cmake -DWAMR_BUILD_WASI_NN=1 ..

+ 7 - 8
core/iwasm/libraries/wasi-nn/test/test_tensorflow.c

@@ -28,7 +28,7 @@ typedef struct {
 // WASI-NN wrappers
 
 error
-wasm_load(char *model_name, graph *graph)
+wasm_load(char *model_name, graph *g)
 {
     FILE *pFile = fopen(model_name, "r");
     if (pFile == NULL)
@@ -64,7 +64,7 @@ wasm_load(char *model_name, graph *graph)
     arr.buf[0].size = result;
     arr.buf[0].buf = buffer;
 
-    error res = load(&arr, tensorflow, cpu, graph);
+    error res = load(&arr, tensorflowlite, cpu, g);
 
     fclose(pFile);
     free(buffer);
@@ -73,13 +73,13 @@ wasm_load(char *model_name, graph *graph)
 }
 
 error
-wasm_init_execution_context(graph graph, graph_execution_context *ctx)
+wasm_init_execution_context(graph g, graph_execution_context *ctx)
 {
-    return init_execution_context(graph, ctx);
+    return init_execution_context(g, ctx);
 }
 
 error
-wasm_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim)
+wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim)
 {
     tensor_dimensions dims;
     dims.size = INPUT_TENSOR_DIMS;
@@ -130,7 +130,7 @@ run_inference(float *input, uint32_t *input_size, uint32_t *output_size,
         exit(1);
     }
 
-    if (wasm_input(ctx, input, input_size) != success) {
+    if (wasm_set_input(ctx, input, input_size) != success) {
         fprintf(stderr, "Error when setting input tensor.");
         exit(1);
     }
@@ -151,7 +151,7 @@ run_inference(float *input, uint32_t *input_size, uint32_t *output_size,
         *output_size = MAX_OUTPUT_TENSOR_SIZE - *output_size;
         if (wasm_get_output(ctx, i, &out_tensor[offset], output_size)
             != success) {
-            fprintf(stderr, "Error when getting input .");
+            fprintf(stderr, "Error when getting output .");
             exit(1);
         }
 
@@ -295,7 +295,6 @@ main()
     test_mult_dimensions();
     printf("################### Testing multiple outputs...\n");
     test_mult_outputs();
-
     printf("Tests: passed!\n");
     return 0;
 }

+ 10 - 1
core/iwasm/libraries/wasi-nn/wasi_nn.cmake

@@ -5,6 +5,15 @@ set (WASI_NN_DIR ${CMAKE_CURRENT_LIST_DIR})
 
 add_definitions (-DWASM_ENABLE_WASI_NN=1)
 
-set (LIBC_WASI_NN_SOURCE ${WASI_NN_DIR}/wasi_nn_native.c ${WASI_NN_DIR}/wasi_nn_tensorflow.cpp)
+include_directories (${WASI_NN_DIR})
+include_directories (${WASI_NN_DIR}/src)
+include_directories (${WASI_NN_DIR}/src/utils)
+
+set (
+    LIBC_WASI_NN_SOURCE
+    ${WASI_NN_DIR}/src/wasi_nn.c
+    ${WASI_NN_DIR}/src/wasi_nn_tensorflowlite.cpp
+    ${WASI_NN_DIR}/src/utils/wasi_nn_app_native.c
+)
 
 set (TENSORFLOW_LIB tensorflow-lite)

+ 19 - 62
core/iwasm/libraries/wasi-nn/wasi_nn.h

@@ -3,63 +3,17 @@
  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  */
 
-#ifndef WASI_NN_WASM_H
-#define WASI_NN_WASM_H
-
-#include "wasi_nn_common.h"
-
 /**
  * Following definition from:
- * [Aug 10th, 2022]
- * https://github.com/WebAssembly/wasi-nn/blob/e5e1a6c31f424c7cd63026cd270e9746775675a0/wasi-nn.wit.md
+ * [Oct 25th, 2022]
+ * https://github.com/WebAssembly/wasi-nn/blob/0f77c48ec195748990ff67928a4b3eef5f16c2de/wasi-nn.wit.md
  */
 
-/* The graph initialization data. */
-
-// This consists of an array of buffers because implementing backends may encode
-// their graph IR in parts (e.g., OpenVINO stores its IR and weights
-// separately).
-typedef struct {
-    uint8_t *buf;
-    uint32_t size;
-} graph_builder;
-
-typedef struct {
-    graph_builder *buf;
-    uint32_t size;
-} graph_builder_array;
-
-/* The dimensions of a tensor. */
-
-// The array length matches the tensor rank and each element in the array
-// describes the size of each dimension.
-typedef struct {
-    uint32_t *buf;
-    uint32_t size;
-} tensor_dimensions;
-
-/* The tensor data. */
+#ifndef WASI_NN_H
+#define WASI_NN_H
 
-// Initially conceived as a sparse representation, each empty cell would be
-// filled with zeros and the array length must match the product of all of the
-// dimensions and the number of bytes in the type (e.g., a 2x2 tensor with
-// 4-byte f32 elements would have a data array of length 16). Naturally, this
-// representation requires some knowledge of how to lay out data in
-// memory--e.g., using row-major ordering--and could perhaps be improved.
-typedef uint8_t *tensor_data;
-
-/* A tensor. */
-
-typedef struct {
-    // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To
-    // represent a tensor containing a single value, use `[1]` for the tensor
-    // dimensions.
-    tensor_dimensions *dimensions;
-    // Describe the type of element in the tensor (e.g., f32).
-    tensor_type type;
-    // Contains the tensor data.
-    tensor_data data;
-} tensor;
+#include <stdint.h>
+#include "wasi_nn_types.h"
 
 /**
  * @brief Load an opaque sequence of bytes to use for inference.
@@ -67,25 +21,31 @@ typedef struct {
  * @param builder   Model builder.
  * @param encoding  Model encoding.
  * @param target    Execution target.
- * @param graph     Graph.
+ * @param g         Graph.
  * @return error    Execution status.
  */
 error
 load(graph_builder_array *builder, graph_encoding encoding,
-     execution_target target, graph *graph)
-    __attribute__((export_module("wasi_nn")))
+     execution_target target, graph *g)
     __attribute__((import_module("wasi_nn")));
 
+/**
+ * INFERENCE
+ *
+ */
+
+// Bind a `graph` to the input and output tensors for an inference.
+typedef uint32_t graph_execution_context;
+
 /**
  * @brief Create an execution instance of a loaded graph.
  *
- * @param graph     Graph.
+ * @param g         Graph.
  * @param ctx       Execution context.
  * @return error    Execution status.
  */
 error
-init_execution_context(graph graph, graph_execution_context *ctx)
-    __attribute__((export_module("wasi_nn")))
+init_execution_context(graph g, graph_execution_context *ctx)
     __attribute__((import_module("wasi_nn")));
 
 /**
@@ -98,7 +58,6 @@ init_execution_context(graph graph, graph_execution_context *ctx)
  */
 error
 set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
-    __attribute__((export_module("wasi_nn")))
     __attribute__((import_module("wasi_nn")));
 
 /**
@@ -108,8 +67,7 @@ set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
  * @return error    Execution status.
  */
 error
-compute(graph_execution_context ctx) __attribute__((export_module("wasi_nn")))
-__attribute__((import_module("wasi_nn")));
+compute(graph_execution_context ctx) __attribute__((import_module("wasi_nn")));
 
 /**
  * @brief Extract the outputs after inference.
@@ -126,7 +84,6 @@ __attribute__((import_module("wasi_nn")));
 error
 get_output(graph_execution_context ctx, uint32_t index,
            tensor_data output_tensor, uint32_t *output_tensor_size)
-    __attribute__((export_module("wasi_nn")))
     __attribute__((import_module("wasi_nn")));
 
 #endif

+ 0 - 44
core/iwasm/libraries/wasi-nn/wasi_nn_common.h

@@ -1,44 +0,0 @@
-/*
- * Copyright (C) 2019 Intel Corporation.  All rights reserved.
- * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- */
-
-#ifndef WASI_NN_COMMON_H
-#define WASI_NN_COMMON_H
-
-#include <stdint.h>
-
-// The type of the elements in a tensor.
-typedef enum { fp16 = 0, fp32, up8, ip32 } tensor_type;
-
-// Describes the encoding of the graph. This allows the API to be implemented by
-// various backends that encode (i.e., serialize) their graph IR with different
-// formats.
-typedef enum { openvino = 0, onnx, tensorflow, pytorch } graph_encoding;
-
-// Define where the graph should be executed.
-typedef enum { cpu = 0, gpu, tpu } execution_target;
-
-// Error codes returned by functions in this API.
-typedef enum {
-    // No error occurred.
-    success = 0,
-    // Caller module passed an invalid argument.
-    invalid_argument,
-    // Invalid encoding.
-    invalid_encoding,
-    // Caller module is missing a memory export.
-    missing_memory,
-    // Device or resource busy.
-    busy,
-    // Runtime Error.
-    runtime_error,
-} error;
-
-// An execution graph for performing inference (i.e., a model).
-typedef uint32_t graph;
-
-// Bind a `graph` to the input and output tensors for an inference.
-typedef uint32_t graph_execution_context;
-
-#endif

+ 0 - 264
core/iwasm/libraries/wasi-nn/wasi_nn_native.c

@@ -1,264 +0,0 @@
-/*
- * Copyright (C) 2019 Intel Corporation.  All rights reserved.
- * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- */
-
-#include <stdio.h>
-#include <assert.h>
-#include <errno.h>
-#include <string.h>
-#include <stdlib.h>
-
-#include "wasi_nn_common.h"
-#include "wasm_export.h"
-#include "bh_platform.h"
-
-#include "wasi_nn.h"
-#include "wasi_nn_tensorflow.hpp"
-#include "logger.h"
-
-/* Definition of 'wasi_nn.h' structs in WASM app format (using offset) */
-
-typedef struct {
-    uint32_t buf_offset;
-    uint32_t size;
-} graph_builder_wasm;
-
-typedef struct {
-    uint32_t buf_offset;
-    uint32_t size;
-} graph_builder_array_wasm;
-
-typedef struct {
-    uint32_t dimensions_offset;
-    tensor_type type;
-    uint32_t data_offset;
-} tensor_wasm;
-
-typedef struct {
-    uint32_t buf_offset;
-    uint32_t size;
-} tensor_dimensions_wasm;
-
-/* Global variables */
-
-static uint8_t _is_initialized;
-static graph_encoding _encoding;
-
-/* Utils */
-
-static error
-check_initialized()
-{
-    if (!_is_initialized) {
-        NN_ERR_PRINTF("Model not initialized.");
-        return invalid_argument;
-    }
-    if (_encoding != tensorflow) {
-        NN_ERR_PRINTF("Model encoding is not tensorflow.");
-        return invalid_argument;
-    }
-    return success;
-}
-
-/* WASI-NN implementation */
-
-error
-wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
-             graph_encoding encoding, execution_target target, graph *graph)
-{
-    NN_DBG_PRINTF("Running wasi_nn_load [encoding=%d, target=%d]...", encoding,
-                  target);
-
-    wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
-    bh_assert(instance);
-
-    if (!wasm_runtime_validate_native_addr(instance, builder,
-                                           sizeof(graph_builder_array_wasm)))
-        return invalid_argument;
-
-    if (!wasm_runtime_validate_app_addr(instance, builder->buf_offset,
-                                        builder->size * sizeof(uint32_t)))
-        return invalid_argument;
-
-    NN_DBG_PRINTF("Graph builder array contains %d elements", builder->size);
-
-    graph_builder_wasm *gb_wasm =
-        (graph_builder_wasm *)wasm_runtime_addr_app_to_native(
-            instance, builder->buf_offset);
-
-    graph_builder *gb_native = (graph_builder *)wasm_runtime_malloc(
-        builder->size * sizeof(graph_builder));
-    if (gb_native == NULL)
-        return missing_memory;
-
-    for (int i = 0; i < builder->size; ++i) {
-        if (!wasm_runtime_validate_app_addr(instance, gb_wasm[i].buf_offset,
-                                            gb_wasm[i].size
-                                                * sizeof(uint8_t))) {
-            wasm_runtime_free(gb_native);
-            return invalid_argument;
-        }
-
-        gb_native[i].buf = (uint8_t *)wasm_runtime_addr_app_to_native(
-            instance, gb_wasm[i].buf_offset);
-        gb_native[i].size = gb_wasm[i].size;
-
-        NN_DBG_PRINTF("Graph builder %d contains %d elements", i,
-                      gb_wasm[i].size);
-    }
-
-    graph_builder_array gba_native = { .buf = gb_native,
-                                       .size = builder->size };
-
-    if (!wasm_runtime_validate_native_addr(instance, graph, sizeof(graph))) {
-        wasm_runtime_free(gb_native);
-        return invalid_argument;
-    }
-
-    switch (encoding) {
-        case tensorflow:
-            break;
-        default:
-            NN_ERR_PRINTF("Only tensorflow is supported.");
-            wasm_runtime_free(gb_native);
-            return invalid_argument;
-    }
-
-    _encoding = encoding;
-    _is_initialized = 1;
-
-    error res = tensorflow_load(gba_native, _encoding, target, graph);
-    NN_DBG_PRINTF("wasi_nn_load finished with status %d [graph=%d]", res,
-                  *graph);
-
-    wasm_runtime_free(gb_native);
-    return res;
-}
-
-error
-wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph graph,
-                               graph_execution_context *ctx)
-{
-    NN_DBG_PRINTF("Running wasi_nn_init_execution_context [graph=%d]...",
-                  graph);
-    error res;
-    if (success != (res = check_initialized()))
-        return res;
-    res = tensorflow_init_execution_context(graph);
-    *ctx = graph;
-    NN_DBG_PRINTF(
-        "wasi_nn_init_execution_context finished with status %d [ctx=%d]", res,
-        *ctx);
-    return res;
-}
-
-error
-wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
-                  uint32_t index, tensor_wasm *input_tensor)
-{
-    NN_DBG_PRINTF("Running wasi_nn_set_input [ctx=%d, index=%d]...", ctx,
-                  index);
-
-    error res;
-    if (success != (res = check_initialized()))
-        return res;
-
-    wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
-    bh_assert(instance);
-
-    if (!wasm_runtime_validate_native_addr(instance, input_tensor,
-                                           sizeof(tensor_wasm)))
-        return invalid_argument;
-
-    if (!wasm_runtime_validate_app_addr(
-            instance, input_tensor->dimensions_offset, sizeof(uint32_t)))
-        return invalid_argument;
-
-    tensor_dimensions_wasm *dimensions_w =
-        (tensor_dimensions_wasm *)wasm_runtime_addr_app_to_native(
-            instance, input_tensor->dimensions_offset);
-
-    if (!wasm_runtime_validate_app_addr(instance, dimensions_w->buf_offset,
-                                        dimensions_w->size * sizeof(uint32_t)))
-        return invalid_argument;
-
-    tensor_dimensions dimensions = {
-        .buf = (uint32_t *)wasm_runtime_addr_app_to_native(
-            instance, dimensions_w->buf_offset),
-        .size = dimensions_w->size
-    };
-
-    NN_DBG_PRINTF("Number of dimensions: %d", dimensions.size);
-    int total_elements = 1;
-    for (int i = 0; i < dimensions.size; ++i) {
-        NN_DBG_PRINTF("Dimension %d: %d", i, dimensions.buf[i]);
-        total_elements *= dimensions.buf[i];
-    }
-    NN_DBG_PRINTF("Tensor type: %d", input_tensor->type);
-
-    if (!wasm_runtime_validate_app_addr(instance, input_tensor->data_offset,
-                                        total_elements))
-        return invalid_argument;
-
-    tensor tensor = { .type = input_tensor->type,
-                      .dimensions = &dimensions,
-                      .data = (uint8_t *)wasm_runtime_addr_app_to_native(
-                          instance, input_tensor->data_offset) };
-
-    res = tensorflow_set_input(ctx, index, &tensor);
-    NN_DBG_PRINTF("wasi_nn_set_input finished with status %d", res);
-    return res;
-}
-
-error
-wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
-{
-    NN_DBG_PRINTF("Running wasi_nn_compute [ctx=%d]...", ctx);
-    error res;
-    if (success != (res = check_initialized()))
-        return res;
-
-    res = tensorflow_compute(ctx);
-    NN_DBG_PRINTF("wasi_nn_compute finished with status %d", res);
-    return res;
-}
-
-error
-wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
-                   uint32_t index, tensor_data output_tensor,
-                   uint32_t *output_tensor_size)
-{
-    NN_DBG_PRINTF("Running wasi_nn_get_output [ctx=%d, index=%d]...", ctx,
-                  index);
-    error res;
-    if (success != (res = check_initialized()))
-        return res;
-
-    res = tensorflow_get_output(ctx, index, output_tensor, output_tensor_size);
-    NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]",
-                  res, *output_tensor_size);
-    return res;
-}
-
-/* Register WASI-NN in WAMR */
-
-/* clang-format off */
-#define REG_NATIVE_FUNC(func_name, signature) \
-    { #func_name, wasi_nn_##func_name, signature, NULL }
-/* clang-format on */
-
-static NativeSymbol native_symbols_wasi_nn[] = {
-    REG_NATIVE_FUNC(load, "(*ii*)i"),
-    REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
-    REG_NATIVE_FUNC(set_input, "(ii*)i"),
-    REG_NATIVE_FUNC(compute, "(i)i"),
-    REG_NATIVE_FUNC(get_output, "(ii**)i"),
-};
-
-uint32_t
-get_wasi_nn_export_apis(NativeSymbol **p_libc_wasi_apis)
-{
-    *p_libc_wasi_apis = native_symbols_wasi_nn;
-    return sizeof(native_symbols_wasi_nn) / sizeof(NativeSymbol);
-}

+ 0 - 40
core/iwasm/libraries/wasi-nn/wasi_nn_tensorflow.hpp

@@ -1,40 +0,0 @@
-/*
- * Copyright (C) 2019 Intel Corporation.  All rights reserved.
- * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- */
-
-#ifndef WASI_NN_TENSORFLOW_HPP
-#define WASI_NN_TENSORFLOW_HPP
-
-#include <stdio.h>
-
-#include "wasi_nn.h"
-#include "logger.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-error
-tensorflow_load(graph_builder_array builder, graph_encoding encoding,
-                execution_target target, graph *graph);
-
-error
-tensorflow_init_execution_context(graph graph);
-
-error
-tensorflow_set_input(graph_execution_context ctx, uint32_t index,
-                     tensor *input_tensor);
-
-error
-tensorflow_compute(graph_execution_context ctx);
-
-error
-tensorflow_get_output(graph_execution_context context, uint32_t index,
-                      tensor_data output_tensor, uint32_t *output_tensor_size);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif

+ 106 - 0
core/iwasm/libraries/wasi-nn/wasi_nn_types.h

@@ -0,0 +1,106 @@
+/*
+ * Copyright (C) 2019 Intel Corporation.  All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+
+#ifndef WASI_NN_TYPES_H
+#define WASI_NN_TYPES_H
+
+/**
+ * ERRORS
+ *
+ */
+
+// Error codes returned by functions in this API.
+typedef enum {
+    // No error occurred.
+    success = 0,
+    // Caller module passed an invalid argument.
+    invalid_argument,
+    // Invalid encoding.
+    invalid_encoding,
+    // Caller module is missing a memory export.
+    missing_memory,
+    // Device or resource busy.
+    busy,
+    // Runtime Error.
+    runtime_error,
+} error;
+
+/**
+ * TENSOR
+ *
+ */
+
+// The dimensions of a tensor.
+//
+// The array length matches the tensor rank and each element in the array
+// describes the size of each dimension.
+typedef struct {
+    uint32_t *buf;
+    uint32_t size;
+} tensor_dimensions;
+
+// The type of the elements in a tensor.
+typedef enum { fp16 = 0, fp32, up8, ip32 } tensor_type;
+
+// The tensor data.
+//
+// Initially conceived as a sparse representation, each empty cell would be
+// filled with zeros and the array length must match the product of all of the
+// dimensions and the number of bytes in the type (e.g., a 2x2 tensor with
+// 4-byte f32 elements would have a data array of length 16). Naturally, this
+// representation requires some knowledge of how to lay out data in
+// memory--e.g., using row-major ordering--and could perhaps be improved.
+typedef uint8_t *tensor_data;
+
+// A tensor.
+typedef struct {
+    // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To
+    // represent a tensor containing a single value, use `[1]` for the tensor
+    // dimensions.
+    tensor_dimensions *dimensions;
+    // Describe the type of element in the tensor (e.g., f32).
+    tensor_type type;
+    // Contains the tensor data.
+    tensor_data data;
+} tensor;
+
+/**
+ * GRAPH
+ *
+ */
+
+// The graph initialization data.
+//
+// This consists of an array of buffers because implementing backends may encode
+// their graph IR in parts (e.g., OpenVINO stores its IR and weights
+// separately).
+typedef struct {
+    uint8_t *buf;
+    uint32_t size;
+} graph_builder;
+
+typedef struct {
+    graph_builder *buf;
+    uint32_t size;
+} graph_builder_array;
+
+// An execution graph for performing inference (i.e., a model).
+typedef uint32_t graph;
+
+// Describes the encoding of the graph. This allows the API to be implemented by
+// various backends that encode (i.e., serialize) their graph IR with different
+// formats.
+typedef enum {
+    openvino = 0,
+    onnx,
+    tensorflow,
+    pytorch,
+    tensorflowlite
+} graph_encoding;
+
+// Define where the graph should be executed.
+typedef enum execution_target { cpu = 0, gpu, tpu } execution_target;
+
+#endif