Sfoglia il codice sorgente

wasi-nn: Support multiple TFLite models (#2002)

Remove restrictions:
- Only 1 WASM app at a time
- Only 1 model at a time
   - `graph` and `graph-execution-context` are ignored

Refer to previous document:
https://github.com/bytecodealliance/wasm-micro-runtime/blob/e8d718096dc56d4b1aa66ec6cd04d6024ca1c6e2/core/iwasm/libraries/wasi-nn/README.md
tonibofarull 2 anni fa
parent
commit
a15a731e12

+ 5 - 0
build-scripts/config_common.cmake

@@ -333,6 +333,11 @@ if (WAMR_BUILD_SGX_IPFS EQUAL 1)
 endif ()
 if (WAMR_BUILD_WASI_NN EQUAL 1)
   message ("     WASI-NN enabled")
+  add_definitions (-DWASM_ENABLE_WASI_NN=1)
+  if (WASI_NN_ENABLE_GPU EQUAL 1)
+      message ("     WASI-NN: GPU enabled")
+      add_definitions (-DWASI_NN_ENABLE_GPU=1)
+  endif ()
 endif ()
 if (WAMR_BUILD_ALLOC_WITH_USER_DATA EQUAL 1)
   add_definitions(-DWASM_MEM_ALLOC_WITH_USER_DATA=1)

+ 7 - 0
build-scripts/runtime_lib.cmake

@@ -109,6 +109,13 @@ if (WAMR_BUILD_WASI_NN EQUAL 1)
         message("Tensorflow is already downloaded.")
     endif()
     set(TENSORFLOW_SOURCE_DIR "${WAMR_ROOT_DIR}/core/deps/tensorflow-src")
+
+    if (WASI_NN_ENABLE_GPU EQUAL 1)
+        # Tensorflow specific:
+        # * https://www.tensorflow.org/lite/guide/build_cmake#available_options_to_build_tensorflow_lite
+        set (TFLITE_ENABLE_GPU ON)
+    endif ()
+
     include_directories (${CMAKE_CURRENT_BINARY_DIR}/flatbuffers/include)
     include_directories (${TENSORFLOW_SOURCE_DIR})
     add_subdirectory(

+ 0 - 9
core/iwasm/libraries/wasi-nn/README.md

@@ -19,12 +19,6 @@ To run the tests we assume that the current directory is the root of the reposit
 
 ### Build the runtime
 
-Build the runtime base image,
-
-```
-docker build -t wasi-nn-base -f core/iwasm/libraries/wasi-nn/test/Dockerfile.base .
-```
-
 Build the runtime image for your execution target type.
 
 `EXECUTION_TYPE` can be:
@@ -84,9 +78,6 @@ Requirements:
 
 Supported:
 
-* Only 1 WASM app at a time.
-* Only 1 model at a time.
-    * `graph` and `graph-execution-context` are ignored.
 * Graph encoding: `tensorflowlite`.
 * Execution target: `cpu` and `gpu`.
 * Tensor type: `fp32`.

+ 36 - 30
core/iwasm/libraries/wasi-nn/src/utils/logger.h

@@ -13,51 +13,57 @@
     (strrchr(__FILE__, '/') ? strrchr(__FILE__, '/') + 1 : __FILE__)
 
 /* Disable a level by removing the define */
-#define ENABLE_ERR_LOG
-#define ENABLE_WARN_LOG
-#define ENABLE_DBG_LOG
-#define ENABLE_INFO_LOG
+#ifndef NN_LOG_LEVEL
+/*
+    0 -> debug, info, warn, err
+    1 -> info, warn, err
+    2 -> warn, err
+    3 -> err
+    4 -> NO LOGS
+*/
+#define NN_LOG_LEVEL 0
+#endif
 
 // Definition of the levels
-#ifdef ENABLE_ERR_LOG
-#define NN_ERR_PRINTF(fmt, ...)                                        \
-    do {                                                               \
-        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
-        printf("\n");                                                  \
-        fflush(stdout);                                                \
+#if NN_LOG_LEVEL <= 3
+#define NN_ERR_PRINTF(fmt, ...)                                              \
+    do {                                                                     \
+        printf("[%s:%d ERROR] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
+        printf("\n");                                                        \
+        fflush(stdout);                                                      \
     } while (0)
 #else
 #define NN_ERR_PRINTF(fmt, ...)
 #endif
-#ifdef ENABLE_WARN_LOG
-#define NN_WARN_PRINTF(fmt, ...)                                       \
-    do {                                                               \
-        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
-        printf("\n");                                                  \
-        fflush(stdout);                                                \
+#if NN_LOG_LEVEL <= 2
+#define NN_WARN_PRINTF(fmt, ...)                                               \
+    do {                                                                       \
+        printf("[%s:%d WARNING] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
+        printf("\n");                                                          \
+        fflush(stdout);                                                        \
     } while (0)
 #else
 #define NN_WARN_PRINTF(fmt, ...)
 #endif
-#ifdef ENABLE_DBG_LOG
-#define NN_DBG_PRINTF(fmt, ...)                                        \
-    do {                                                               \
-        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
-        printf("\n");                                                  \
-        fflush(stdout);                                                \
+#if NN_LOG_LEVEL <= 1
+#define NN_INFO_PRINTF(fmt, ...)                                            \
+    do {                                                                    \
+        printf("[%s:%d INFO] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
+        printf("\n");                                                       \
+        fflush(stdout);                                                     \
     } while (0)
 #else
-#define NN_DBG_PRINTF(fmt, ...)
+#define NN_INFO_PRINTF(fmt, ...)
 #endif
-#ifdef ENABLE_INFO_LOG
-#define NN_INFO_PRINTF(fmt, ...)                                       \
-    do {                                                               \
-        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
-        printf("\n");                                                  \
-        fflush(stdout);                                                \
+#if NN_LOG_LEVEL <= 0
+#define NN_DBG_PRINTF(fmt, ...)                                              \
+    do {                                                                     \
+        printf("[%s:%d DEBUG] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
+        printf("\n");                                                        \
+        fflush(stdout);                                                      \
     } while (0)
 #else
-#define NN_INFO_PRINTF(fmt, ...)
+#define NN_DBG_PRINTF(fmt, ...)
 #endif
 
 #endif

+ 21 - 17
core/iwasm/libraries/wasi-nn/src/wasi_nn.c

@@ -22,13 +22,14 @@
 
 /* Definition of 'wasi_nn.h' structs in WASM app format (using offset) */
 
-typedef error (*LOAD)(graph_builder_array *, graph_encoding, execution_target,
-                      graph *);
-typedef error (*INIT_EXECUTION_CONTEXT)(graph, graph_execution_context *);
-typedef error (*SET_INPUT)(graph_execution_context, uint32_t, tensor *);
-typedef error (*COMPUTE)(graph_execution_context);
-typedef error (*GET_OUTPUT)(graph_execution_context, uint32_t, tensor_data,
-                            uint32_t *);
+typedef error (*LOAD)(void *, graph_builder_array *, graph_encoding,
+                      execution_target, graph *);
+typedef error (*INIT_EXECUTION_CONTEXT)(void *, graph,
+                                        graph_execution_context *);
+typedef error (*SET_INPUT)(void *, graph_execution_context, uint32_t, tensor *);
+typedef error (*COMPUTE)(void *, graph_execution_context);
+typedef error (*GET_OUTPUT)(void *, graph_execution_context, uint32_t,
+                            tensor_data, uint32_t *);
 
 typedef struct {
     LOAD load;
@@ -123,12 +124,12 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
         goto fail;
     }
 
-    res = lookup[encoding].load(&builder_native, encoding, target, g);
+    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
+    res = lookup[encoding].load(wasi_nn_ctx->tflite_ctx, &builder_native,
+                                encoding, target, g);
 
     NN_DBG_PRINTF("wasi_nn_load finished with status %d [graph=%d]", res, *g);
 
-    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
-
     wasi_nn_ctx->current_encoding = encoding;
     wasi_nn_ctx->is_initialized = true;
 
@@ -160,8 +161,9 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
         return invalid_argument;
     }
 
-    res = lookup[wasi_nn_ctx->current_encoding].init_execution_context(g, ctx);
-    *ctx = g;
+    res = lookup[wasi_nn_ctx->current_encoding].init_execution_context(
+        wasi_nn_ctx->tflite_ctx, g, ctx);
+
     NN_DBG_PRINTF(
         "wasi_nn_init_execution_context finished with status %d [ctx=%d]", res,
         *ctx);
@@ -189,8 +191,8 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
                                     &input_tensor_native)))
         return res;
 
-    res = lookup[wasi_nn_ctx->current_encoding].set_input(ctx, index,
-                                                          &input_tensor_native);
+    res = lookup[wasi_nn_ctx->current_encoding].set_input(
+        wasi_nn_ctx->tflite_ctx, ctx, index, &input_tensor_native);
 
     // XXX: Free intermediate structure pointers
     if (input_tensor_native.dimensions)
@@ -213,7 +215,8 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
     if (success != (res = is_model_initialized(wasi_nn_ctx)))
         return res;
 
-    res = lookup[wasi_nn_ctx->current_encoding].compute(ctx);
+    res = lookup[wasi_nn_ctx->current_encoding].compute(wasi_nn_ctx->tflite_ctx,
+                                                        ctx);
     NN_DBG_PRINTF("wasi_nn_compute finished with status %d", res);
     return res;
 }
@@ -241,7 +244,7 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
     }
 
     res = lookup[wasi_nn_ctx->current_encoding].get_output(
-        ctx, index, output_tensor, output_tensor_size);
+        wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, output_tensor_size);
     NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]",
                   res, *output_tensor_size);
     return res;
@@ -261,6 +264,7 @@ wasi_nn_initialize()
     }
     wasi_nn_ctx->is_initialized = true;
     wasi_nn_ctx->current_encoding = 3;
+    tensorflowlite_initialize(&wasi_nn_ctx->tflite_ctx);
     return wasi_nn_ctx;
 }
 
@@ -275,7 +279,7 @@ wasi_nn_destroy(WASINNContext *wasi_nn_ctx)
     NN_DBG_PRINTF("Freeing wasi-nn");
     NN_DBG_PRINTF("-> is_initialized: %d", wasi_nn_ctx->is_initialized);
     NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->current_encoding);
-    tensorflowlite_destroy();
+    tensorflowlite_destroy(wasi_nn_ctx->tflite_ctx);
     wasm_runtime_free(wasi_nn_ctx);
 }
 

+ 1 - 0
core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h

@@ -11,6 +11,7 @@
 typedef struct {
     bool is_initialized;
     graph_encoding current_encoding;
+    void *tflite_ctx;
 } WASINNContext;
 
 /**

+ 213 - 67
core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp

@@ -16,25 +16,105 @@
 #include <tensorflow/lite/model.h>
 #include <tensorflow/lite/optional_debug_tools.h>
 #include <tensorflow/lite/error_reporter.h>
-#include <tensorflow/lite/delegates/gpu/delegate.h>
 
-/* Global variables */
+#if defined(WASI_NN_ENABLE_GPU)
+#include <tensorflow/lite/delegates/gpu/delegate.h>
+#endif
+
+/* Maximum number of graphs per WASM instance */
+#define MAX_GRAPHS_PER_INST 10
+/* Maximum number of graph execution context per WASM instance*/
+#define MAX_GRAPH_EXEC_CONTEXTS_PER_INST 10
+
+typedef struct {
+    std::unique_ptr<tflite::Interpreter> interpreter;
+} Interpreter;
+
+typedef struct {
+    char *model_pointer;
+    std::unique_ptr<tflite::FlatBufferModel> model;
+    execution_target target;
+} Model;
+
+typedef struct {
+    uint32_t current_models;
+    Model models[MAX_GRAPHS_PER_INST];
+    uint32_t current_interpreters;
+    Interpreter interpreters[MAX_GRAPH_EXEC_CONTEXTS_PER_INST];
+    korp_mutex g_lock;
+} TFLiteContext;
+
+/* Utils */
+
+static error
+initialize_g(TFLiteContext *tfl_ctx, graph *g)
+{
+    os_mutex_lock(&tfl_ctx->g_lock);
+    if (tfl_ctx->current_models == MAX_GRAPHS_PER_INST) {
+        os_mutex_unlock(&tfl_ctx->g_lock);
+        NN_ERR_PRINTF("Excedded max graphs per WASM instance");
+        return runtime_error;
+    }
+    *g = tfl_ctx->current_models++;
+    os_mutex_unlock(&tfl_ctx->g_lock);
+    return success;
+}
+static error
+initialize_graph_ctx(TFLiteContext *tfl_ctx, graph g,
+                     graph_execution_context *ctx)
+{
+    os_mutex_lock(&tfl_ctx->g_lock);
+    if (tfl_ctx->current_interpreters == MAX_GRAPH_EXEC_CONTEXTS_PER_INST) {
+        os_mutex_unlock(&tfl_ctx->g_lock);
+        NN_ERR_PRINTF("Excedded max graph execution context per WASM instance");
+        return runtime_error;
+    }
+    *ctx = tfl_ctx->current_interpreters++;
+    os_mutex_unlock(&tfl_ctx->g_lock);
+    return success;
+}
 
-static std::unique_ptr<tflite::Interpreter> interpreter;
-static std::unique_ptr<tflite::FlatBufferModel> model;
+static error
+is_valid_graph(TFLiteContext *tfl_ctx, graph g)
+{
+    if (g >= MAX_GRAPHS_PER_INST) {
+        NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
+        return runtime_error;
+    }
+    if (tfl_ctx->models[g].model_pointer == NULL) {
+        NN_ERR_PRINTF("Context (model) non-initialized.");
+        return runtime_error;
+    }
+    if (tfl_ctx->models[g].model == NULL) {
+        NN_ERR_PRINTF("Context (tflite model) non-initialized.");
+        return runtime_error;
+    }
+    return success;
+}
 
-static char *model_pointer = NULL;
+static error
+is_valid_graph_execution_context(TFLiteContext *tfl_ctx,
+                                 graph_execution_context ctx)
+{
+    if (ctx >= MAX_GRAPH_EXEC_CONTEXTS_PER_INST) {
+        NN_ERR_PRINTF("Invalid graph execution context: %d >= %d", ctx,
+                      MAX_GRAPH_EXEC_CONTEXTS_PER_INST);
+        return runtime_error;
+    }
+    if (tfl_ctx->interpreters[ctx].interpreter == NULL) {
+        NN_ERR_PRINTF("Context (interpreter) non-initialized.");
+        return runtime_error;
+    }
+    return success;
+}
 
 /* WASI-NN (tensorflow) implementation */
 
 error
-tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
-                    execution_target target, graph *g)
+tensorflowlite_load(void *tflite_ctx, graph_builder_array *builder,
+                    graph_encoding encoding, execution_target target, graph *g)
 {
-    if (model_pointer != NULL) {
-        wasm_runtime_free(model_pointer);
-        model_pointer = NULL;
-    }
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
 
     if (builder->size != 1) {
         NN_ERR_PRINTF("Unexpected builder format.");
@@ -51,39 +131,68 @@ tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
         return invalid_argument;
     }
 
+    error res;
+    if (success != (res = initialize_g(tfl_ctx, g)))
+        return res;
+
     uint32_t size = builder->buf[0].size;
 
-    model_pointer = (char *)wasm_runtime_malloc(size);
-    if (model_pointer == NULL) {
+    // Save model
+    tfl_ctx->models[*g].model_pointer = (char *)wasm_runtime_malloc(size);
+    if (tfl_ctx->models[*g].model_pointer == NULL) {
         NN_ERR_PRINTF("Error when allocating memory for model.");
         return missing_memory;
     }
 
-    bh_memcpy_s(model_pointer, size, builder->buf[0].buf, size);
+    bh_memcpy_s(tfl_ctx->models[*g].model_pointer, size, builder->buf[0].buf,
+                size);
 
-    model = tflite::FlatBufferModel::BuildFromBuffer(model_pointer, size, NULL);
-    if (model == NULL) {
+    // Save model flatbuffer
+    tfl_ctx->models[*g].model =
+        std::move(tflite::FlatBufferModel::BuildFromBuffer(
+            tfl_ctx->models[*g].model_pointer, size, NULL));
+
+    if (tfl_ctx->models[*g].model == NULL) {
         NN_ERR_PRINTF("Loading model error.");
-        wasm_runtime_free(model_pointer);
-        model_pointer = NULL;
+        wasm_runtime_free(tfl_ctx->models[*g].model_pointer);
+        tfl_ctx->models[*g].model_pointer = NULL;
         return missing_memory;
     }
 
+    // Save target
+    tfl_ctx->models[*g].target = target;
+    return success;
+}
+
+error
+tensorflowlite_init_execution_context(void *tflite_ctx, graph g,
+                                      graph_execution_context *ctx)
+{
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
+
+    error res;
+    if (success != (res = is_valid_graph(tfl_ctx, g)))
+        return res;
+
+    if (success != (res = initialize_graph_ctx(tfl_ctx, g, ctx)))
+        return res;
+
     // Build the interpreter with the InterpreterBuilder.
     tflite::ops::builtin::BuiltinOpResolver resolver;
-    tflite::InterpreterBuilder tflite_builder(*model, resolver);
-    tflite_builder(&interpreter);
-    if (interpreter == NULL) {
+    tflite::InterpreterBuilder tflite_builder(*tfl_ctx->models[g].model,
+                                              resolver);
+    tflite_builder(&tfl_ctx->interpreters[*ctx].interpreter);
+    if (tfl_ctx->interpreters[*ctx].interpreter == NULL) {
         NN_ERR_PRINTF("Error when generating the interpreter.");
-        wasm_runtime_free(model_pointer);
-        model_pointer = NULL;
         return missing_memory;
     }
 
     bool use_default = false;
-    switch (target) {
+    switch (tfl_ctx->models[g].target) {
         case gpu:
         {
+#if defined(WASI_NN_ENABLE_GPU)
+            NN_WARN_PRINTF("GPU enabled.");
             // https://www.tensorflow.org/lite/performance/gpu
             auto options = TfLiteGpuDelegateOptionsV2Default();
             options.inference_preference =
@@ -91,10 +200,16 @@ tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
             options.inference_priority1 =
                 TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
             auto *delegate = TfLiteGpuDelegateV2Create(&options);
-            if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
+            if (tfl_ctx->interpreters[*ctx]
+                    .interpreter->ModifyGraphWithDelegate(delegate)
+                != kTfLiteOk) {
                 NN_ERR_PRINTF("Error when enabling GPU delegate.");
                 use_default = true;
             }
+#else
+            NN_WARN_PRINTF("GPU not enabled.");
+            use_default = true;
+#endif
             break;
         }
         default:
@@ -103,36 +218,28 @@ tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
     if (use_default)
         NN_WARN_PRINTF("Default encoding is CPU.");
 
+    tfl_ctx->interpreters[*ctx].interpreter->AllocateTensors();
     return success;
 }
 
 error
-tensorflowlite_init_execution_context(graph g, graph_execution_context *ctx)
+tensorflowlite_set_input(void *tflite_ctx, graph_execution_context ctx,
+                         uint32_t index, tensor *input_tensor)
 {
-    if (interpreter == NULL) {
-        NN_ERR_PRINTF("Non-initialized interpreter.");
-        return runtime_error;
-    }
-    interpreter->AllocateTensors();
-    return success;
-}
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
 
-error
-tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
-                         tensor *input_tensor)
-{
-    if (interpreter == NULL) {
-        NN_ERR_PRINTF("Non-initialized interpreter.");
-        return runtime_error;
-    }
+    error res;
+    if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx)))
+        return res;
 
-    uint32_t num_tensors = interpreter->inputs().size();
+    uint32_t num_tensors =
+        tfl_ctx->interpreters[ctx].interpreter->inputs().size();
     NN_DBG_PRINTF("Number of tensors (%d)", num_tensors);
     if (index + 1 > num_tensors) {
         return runtime_error;
     }
 
-    auto tensor = interpreter->input_tensor(index);
+    auto tensor = tfl_ctx->interpreters[ctx].interpreter->input_tensor(index);
     if (tensor == NULL) {
         NN_ERR_PRINTF("Missing memory");
         return missing_memory;
@@ -152,7 +259,9 @@ tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
         return invalid_argument;
     }
 
-    auto *input = interpreter->typed_input_tensor<float>(index);
+    auto *input =
+        tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<float>(
+            index);
     if (input == NULL)
         return missing_memory;
 
@@ -162,34 +271,38 @@ tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
 }
 
 error
-tensorflowlite_compute(graph_execution_context ctx)
+tensorflowlite_compute(void *tflite_ctx, graph_execution_context ctx)
 {
-    if (interpreter == NULL) {
-        NN_ERR_PRINTF("Non-initialized interpreter.");
-        return runtime_error;
-    }
-    interpreter->Invoke();
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
+
+    error res;
+    if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx)))
+        return res;
+
+    tfl_ctx->interpreters[ctx].interpreter->Invoke();
     return success;
 }
 
 error
-tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
-                          tensor_data output_tensor,
+tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
+                          uint32_t index, tensor_data output_tensor,
                           uint32_t *output_tensor_size)
 {
-    if (interpreter == NULL) {
-        NN_ERR_PRINTF("Non-initialized interpreter.");
-        return runtime_error;
-    }
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
 
-    uint32_t num_output_tensors = interpreter->outputs().size();
+    error res;
+    if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx)))
+        return res;
+
+    uint32_t num_output_tensors =
+        tfl_ctx->interpreters[ctx].interpreter->outputs().size();
     NN_DBG_PRINTF("Number of tensors (%d)", num_output_tensors);
 
     if (index + 1 > num_output_tensors) {
         return runtime_error;
     }
 
-    auto tensor = interpreter->output_tensor(index);
+    auto tensor = tfl_ctx->interpreters[ctx].interpreter->output_tensor(index);
     if (tensor == NULL) {
         NN_ERR_PRINTF("Missing memory");
         return missing_memory;
@@ -204,7 +317,9 @@ tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
         return missing_memory;
     }
 
-    float *tensor_f = interpreter->typed_output_tensor<float>(index);
+    float *tensor_f =
+        tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
+            index);
     for (uint32_t i = 0; i < model_tensor_size; ++i)
         NN_DBG_PRINTF("output: %f", tensor_f[i]);
 
@@ -215,20 +330,51 @@ tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
 }
 
 void
-tensorflowlite_destroy()
+tensorflowlite_initialize(void **tflite_ctx)
+{
+    TFLiteContext *tfl_ctx = new TFLiteContext();
+    if (tfl_ctx == NULL) {
+        NN_ERR_PRINTF("Error when allocating memory for tensorflowlite.");
+        return;
+    }
+
+    NN_DBG_PRINTF("Initializing models.");
+    tfl_ctx->current_models = 0;
+    for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) {
+        tfl_ctx->models[i].model_pointer = NULL;
+    }
+    NN_DBG_PRINTF("Initializing interpreters.");
+    tfl_ctx->current_interpreters = 0;
+
+    if (os_mutex_init(&tfl_ctx->g_lock) != 0) {
+        NN_ERR_PRINTF("Error while initializing the lock");
+    }
+
+    *tflite_ctx = (void *)tfl_ctx;
+}
+
+void
+tensorflowlite_destroy(void *tflite_ctx)
 {
     /*
-        TensorFlow Lite memory is man
+        TensorFlow Lite memory is internally managed by tensorflow
 
         Related issues:
         * https://github.com/tensorflow/tensorflow/issues/15880
     */
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
+
     NN_DBG_PRINTF("Freeing memory.");
-    model.reset(nullptr);
-    model = NULL;
-    interpreter.reset(nullptr);
-    interpreter = NULL;
-    wasm_runtime_free(model_pointer);
-    model_pointer = NULL;
+    for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) {
+        tfl_ctx->models[i].model.reset();
+        if (tfl_ctx->models[i].model_pointer)
+            wasm_runtime_free(tfl_ctx->models[i].model_pointer);
+        tfl_ctx->models[i].model_pointer = NULL;
+    }
+    for (int i = 0; i < MAX_GRAPH_EXEC_CONTEXTS_PER_INST; ++i) {
+        tfl_ctx->interpreters[i].interpreter.reset();
+    }
+    os_mutex_destroy(&tfl_ctx->g_lock);
+    delete tfl_ctx;
     NN_DBG_PRINTF("Memory free'd.");
 }

+ 13 - 9
core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.hpp

@@ -13,26 +13,30 @@ extern "C" {
 #endif
 
 error
-tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
-                    execution_target target, graph *g);
+tensorflowlite_load(void *tflite_ctx, graph_builder_array *builder,
+                    graph_encoding encoding, execution_target target, graph *g);
 
 error
-tensorflowlite_init_execution_context(graph g, graph_execution_context *ctx);
+tensorflowlite_init_execution_context(void *tflite_ctx, graph g,
+                                      graph_execution_context *ctx);
 
 error
-tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
-                         tensor *input_tensor);
+tensorflowlite_set_input(void *tflite_ctx, graph_execution_context ctx,
+                         uint32_t index, tensor *input_tensor);
 
 error
-tensorflowlite_compute(graph_execution_context ctx);
+tensorflowlite_compute(void *tflite_ctx, graph_execution_context ctx);
 
 error
-tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
-                          tensor_data output_tensor,
+tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
+                          uint32_t index, tensor_data output_tensor,
                           uint32_t *output_tensor_size);
 
 void
-tensorflowlite_destroy();
+tensorflowlite_initialize(void **tflite_ctx);
+
+void
+tensorflowlite_destroy(void *tflite_ctx);
 
 #ifdef __cplusplus
 }

+ 0 - 22
core/iwasm/libraries/wasi-nn/test/Dockerfile.base

@@ -1,22 +0,0 @@
-# Copyright (C) 2019 Intel Corporation.  All rights reserved.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-FROM ubuntu:20.04 AS base
-
-ENV DEBIAN_FRONTEND=noninteractive
-
-RUN apt-get update && apt-get install -y \
-    cmake build-essential git
-
-WORKDIR /home/wamr
-
-COPY . .
-
-WORKDIR /home/wamr/core/iwasm/libraries/wasi-nn/test/build
-
-RUN cmake \
-  -DWAMR_BUILD_WASI_NN=1 \
-  -DTFLITE_ENABLE_GPU=ON \
-  ..
-
-RUN make -j $(grep -c ^processor /proc/cpuinfo)

+ 21 - 2
core/iwasm/libraries/wasi-nn/test/Dockerfile.cpu

@@ -1,8 +1,27 @@
 # Copyright (C) 2019 Intel Corporation.  All rights reserved.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-FROM ubuntu:20.04
+FROM ubuntu:20.04 AS base
 
-COPY --from=wasi-nn-base /home/wamr/core/iwasm/libraries/wasi-nn/test/build/iwasm /run/iwasm
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update && apt-get install -y \
+    cmake build-essential git
+
+WORKDIR /home/wamr
+
+COPY . .
+
+WORKDIR /home/wamr/core/iwasm/libraries/wasi-nn/test/build
+
+RUN cmake \
+  -DWAMR_BUILD_WASI_NN=1 \
+  ..
+
+RUN make -j $(grep -c ^processor /proc/cpuinfo)
+
+FROM ubuntu:22.04
+
+COPY --from=base /home/wamr/core/iwasm/libraries/wasi-nn/test/build/iwasm /run/iwasm
 
 ENTRYPOINT [ "/run/iwasm" ]

+ 21 - 1
core/iwasm/libraries/wasi-nn/test/Dockerfile.nvidia-gpu

@@ -1,6 +1,26 @@
 # Copyright (C) 2019 Intel Corporation.  All rights reserved.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+FROM ubuntu:20.04 AS base
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update && apt-get install -y \
+    cmake build-essential git
+
+WORKDIR /home/wamr
+
+COPY . .
+
+WORKDIR /home/wamr/core/iwasm/libraries/wasi-nn/test/build
+
+RUN cmake \
+  -DWAMR_BUILD_WASI_NN=1 \
+  -DWASI_NN_ENABLE_GPU=1 \
+  ..
+
+RUN make -j $(grep -c ^processor /proc/cpuinfo)
+
 FROM nvidia/cuda:11.3.0-runtime-ubuntu20.04
 
 RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -15,6 +35,6 @@ RUN mkdir -p /etc/OpenCL/vendors && \
 ENV NVIDIA_VISIBLE_DEVICES=all
 ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
 
-COPY --from=wasi-nn-base /home/wamr/core/iwasm/libraries/wasi-nn/test/build/iwasm /run/iwasm
+COPY --from=base /home/wamr/core/iwasm/libraries/wasi-nn/test/build/iwasm /run/iwasm
 
 ENTRYPOINT [ "/run/iwasm" ]

+ 3 - 2
core/iwasm/libraries/wasi-nn/test/build.sh

@@ -7,8 +7,9 @@
     -Wl,--allow-undefined \
     -Wl,--strip-all,--no-entry \
     --sysroot=/opt/wasi-sdk/share/wasi-sysroot \
-    -I.. \
-    -o test_tensorflow.wasm test_tensorflow.c
+    -I.. -I../src/utils \
+    -o test_tensorflow.wasm \
+    test_tensorflow.c utils.c
 
 # TFLite models to use in the tests
 

+ 13 - 186
core/iwasm/libraries/wasi-nn/test/test_tensorflow.c

@@ -5,185 +5,12 @@
 
 #include <stdio.h>
 #include <stdlib.h>
+#include <assert.h>
 #include <string.h>
-#include <stdint.h>
 #include <math.h>
-#include <assert.h>
-#include "wasi_nn.h"
-
-#include <fcntl.h>
-#include <errno.h>
-
-#define MAX_MODEL_SIZE 85000000
-#define MAX_OUTPUT_TENSOR_SIZE 200
-#define INPUT_TENSOR_DIMS 4
-#define EPSILON 1e-8
-
-typedef struct {
-    float *input_tensor;
-    uint32_t *dim;
-    uint32_t elements;
-} input_info;
-
-// WASI-NN wrappers
-
-error
-wasm_load(char *model_name, graph *g, execution_target target)
-{
-    FILE *pFile = fopen(model_name, "r");
-    if (pFile == NULL)
-        return invalid_argument;
-
-    uint8_t *buffer;
-    size_t result;
-
-    // allocate memory to contain the whole file:
-    buffer = (uint8_t *)malloc(sizeof(uint8_t) * MAX_MODEL_SIZE);
-    if (buffer == NULL) {
-        fclose(pFile);
-        return missing_memory;
-    }
-
-    result = fread(buffer, 1, MAX_MODEL_SIZE, pFile);
-    if (result <= 0) {
-        fclose(pFile);
-        free(buffer);
-        return missing_memory;
-    }
-
-    graph_builder_array arr;
-
-    arr.size = 1;
-    arr.buf = (graph_builder *)malloc(sizeof(graph_builder));
-    if (arr.buf == NULL) {
-        fclose(pFile);
-        free(buffer);
-        return missing_memory;
-    }
-
-    arr.buf[0].size = result;
-    arr.buf[0].buf = buffer;
-
-    error res = load(&arr, tensorflowlite, target, g);
-
-    fclose(pFile);
-    free(buffer);
-    free(arr.buf);
-    return res;
-}
-
-error
-wasm_init_execution_context(graph g, graph_execution_context *ctx)
-{
-    return init_execution_context(g, ctx);
-}
-
-error
-wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim)
-{
-    tensor_dimensions dims;
-    dims.size = INPUT_TENSOR_DIMS;
-    dims.buf = (uint32_t *)malloc(dims.size * sizeof(uint32_t));
-    if (dims.buf == NULL)
-        return missing_memory;
-
-    tensor tensor;
-    tensor.dimensions = &dims;
-    for (int i = 0; i < tensor.dimensions->size; ++i)
-        tensor.dimensions->buf[i] = dim[i];
-    tensor.type = fp32;
-    tensor.data = (uint8_t *)input_tensor;
-    error err = set_input(ctx, 0, &tensor);
-
-    free(dims.buf);
-    return err;
-}
-
-error
-wasm_compute(graph_execution_context ctx)
-{
-    return compute(ctx);
-}
-
-error
-wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor,
-                uint32_t *out_size)
-{
-    return get_output(ctx, index, (uint8_t *)out_tensor, out_size);
-}
-
-// Inference
-
-float *
-run_inference(execution_target target, float *input, uint32_t *input_size,
-              uint32_t *output_size, char *model_name,
-              uint32_t num_output_tensors)
-{
-    graph graph;
-    if (wasm_load(model_name, &graph, target) != success) {
-        fprintf(stderr, "Error when loading model.");
-        exit(1);
-    }
-
-    graph_execution_context ctx;
-    if (wasm_init_execution_context(graph, &ctx) != success) {
-        fprintf(stderr, "Error when initialixing execution context.");
-        exit(1);
-    }
-
-    if (wasm_set_input(ctx, input, input_size) != success) {
-        fprintf(stderr, "Error when setting input tensor.");
-        exit(1);
-    }
-
-    if (wasm_compute(ctx) != success) {
-        fprintf(stderr, "Error when running inference.");
-        exit(1);
-    }
-
-    float *out_tensor = (float *)malloc(sizeof(float) * MAX_OUTPUT_TENSOR_SIZE);
-    if (out_tensor == NULL) {
-        fprintf(stderr, "Error when allocating memory for output tensor.");
-        exit(1);
-    }
-
-    uint32_t offset = 0;
-    for (int i = 0; i < num_output_tensors; ++i) {
-        *output_size = MAX_OUTPUT_TENSOR_SIZE - *output_size;
-        if (wasm_get_output(ctx, i, &out_tensor[offset], output_size)
-            != success) {
-            fprintf(stderr, "Error when getting output .");
-            exit(1);
-        }
-
-        offset += *output_size;
-    }
-    *output_size = offset;
-    return out_tensor;
-}
-
-// UTILS
-
-input_info
-create_input(int *dims)
-{
-    input_info input = { .dim = NULL, .input_tensor = NULL, .elements = 1 };
-
-    input.dim = malloc(INPUT_TENSOR_DIMS * sizeof(uint32_t));
-    if (input.dim)
-        for (int i = 0; i < INPUT_TENSOR_DIMS; ++i) {
-            input.dim[i] = dims[i];
-            input.elements *= dims[i];
-        }
-
-    input.input_tensor = malloc(input.elements * sizeof(float));
-    for (int i = 0; i < input.elements; ++i)
-        input.input_tensor[i] = i;
-
-    return input;
-}
 
-// TESTS
+#include "utils.h"
+#include "logger.h"
 
 void
 test_sum(execution_target target)
@@ -215,7 +42,7 @@ test_max(execution_target target)
 
     assert(output_size == 1);
     assert(fabs(output[0] - 24.0) < EPSILON);
-    printf("Result: max is %f\n", output[0]);
+    NN_INFO_PRINTF("Result: max is %f", output[0]);
 
     free(input.dim);
     free(input.input_tensor);
@@ -235,7 +62,7 @@ test_average(execution_target target)
 
     assert(output_size == 1);
     assert(fabs(output[0] - 12.0) < EPSILON);
-    printf("Result: average is %f\n", output[0]);
+    NN_INFO_PRINTF("Result: average is %f", output[0]);
 
     free(input.dim);
     free(input.input_tensor);
@@ -291,7 +118,7 @@ main()
 {
     char *env = getenv("TARGET");
     if (env == NULL) {
-        printf("Usage:\n--env=\"TARGET=[cpu|gpu]\"\n");
+        NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu]\"");
         return 1;
     }
     execution_target target;
@@ -300,20 +127,20 @@ main()
     else if (strcmp(env, "gpu") == 0)
         target = gpu;
     else {
-        printf("Wrong target!");
+        NN_ERR_PRINTF("Wrong target!");
         return 1;
     }
-    printf("################### Testing sum...\n");
+    NN_INFO_PRINTF("################### Testing sum...");
     test_sum(target);
-    printf("################### Testing max...\n");
+    NN_INFO_PRINTF("################### Testing max...");
     test_max(target);
-    printf("################### Testing average...\n");
+    NN_INFO_PRINTF("################### Testing average...");
     test_average(target);
-    printf("################### Testing multiple dimensions...\n");
+    NN_INFO_PRINTF("################### Testing multiple dimensions...");
     test_mult_dimensions(target);
-    printf("################### Testing multiple outputs...\n");
+    NN_INFO_PRINTF("################### Testing multiple outputs...");
     test_mult_outputs(target);
 
-    printf("Tests: passed!\n");
+    NN_INFO_PRINTF("Tests: passed!");
     return 0;
 }

+ 162 - 0
core/iwasm/libraries/wasi-nn/test/utils.c

@@ -0,0 +1,162 @@
+/*
+ * Copyright (C) 2019 Intel Corporation.  All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+
+#include "utils.h"
+#include "logger.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+error
+wasm_load(char *model_name, graph *g, execution_target target)
+{
+    FILE *pFile = fopen(model_name, "r");
+    if (pFile == NULL)
+        return invalid_argument;
+
+    uint8_t *buffer;
+    size_t result;
+
+    // allocate memory to contain the whole file:
+    buffer = (uint8_t *)malloc(sizeof(uint8_t) * MAX_MODEL_SIZE);
+    if (buffer == NULL) {
+        fclose(pFile);
+        return missing_memory;
+    }
+
+    result = fread(buffer, 1, MAX_MODEL_SIZE, pFile);
+    if (result <= 0) {
+        fclose(pFile);
+        free(buffer);
+        return missing_memory;
+    }
+
+    graph_builder_array arr;
+
+    arr.size = 1;
+    arr.buf = (graph_builder *)malloc(sizeof(graph_builder));
+    if (arr.buf == NULL) {
+        fclose(pFile);
+        free(buffer);
+        return missing_memory;
+    }
+
+    arr.buf[0].size = result;
+    arr.buf[0].buf = buffer;
+
+    error res = load(&arr, tensorflowlite, target, g);
+
+    fclose(pFile);
+    free(buffer);
+    free(arr.buf);
+    return res;
+}
+
+error
+wasm_init_execution_context(graph g, graph_execution_context *ctx)
+{
+    return init_execution_context(g, ctx);
+}
+
+error
+wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim)
+{
+    tensor_dimensions dims;
+    dims.size = INPUT_TENSOR_DIMS;
+    dims.buf = (uint32_t *)malloc(dims.size * sizeof(uint32_t));
+    if (dims.buf == NULL)
+        return missing_memory;
+
+    tensor tensor;
+    tensor.dimensions = &dims;
+    for (int i = 0; i < tensor.dimensions->size; ++i)
+        tensor.dimensions->buf[i] = dim[i];
+    tensor.type = fp32;
+    tensor.data = (uint8_t *)input_tensor;
+    error err = set_input(ctx, 0, &tensor);
+
+    free(dims.buf);
+    return err;
+}
+
+error
+wasm_compute(graph_execution_context ctx)
+{
+    return compute(ctx);
+}
+
+error
+wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor,
+                uint32_t *out_size)
+{
+    return get_output(ctx, index, (uint8_t *)out_tensor, out_size);
+}
+
+float *
+run_inference(execution_target target, float *input, uint32_t *input_size,
+              uint32_t *output_size, char *model_name,
+              uint32_t num_output_tensors)
+{
+    graph graph;
+    if (wasm_load(model_name, &graph, target) != success) {
+        NN_ERR_PRINTF("Error when loading model.");
+        exit(1);
+    }
+
+    graph_execution_context ctx;
+    if (wasm_init_execution_context(graph, &ctx) != success) {
+        NN_ERR_PRINTF("Error when initialixing execution context.");
+        exit(1);
+    }
+
+    if (wasm_set_input(ctx, input, input_size) != success) {
+        NN_ERR_PRINTF("Error when setting input tensor.");
+        exit(1);
+    }
+
+    if (wasm_compute(ctx) != success) {
+        NN_ERR_PRINTF("Error when running inference.");
+        exit(1);
+    }
+
+    float *out_tensor = (float *)malloc(sizeof(float) * MAX_OUTPUT_TENSOR_SIZE);
+    if (out_tensor == NULL) {
+        NN_ERR_PRINTF("Error when allocating memory for output tensor.");
+        exit(1);
+    }
+
+    uint32_t offset = 0;
+    for (int i = 0; i < num_output_tensors; ++i) {
+        *output_size = MAX_OUTPUT_TENSOR_SIZE - *output_size;
+        if (wasm_get_output(ctx, i, &out_tensor[offset], output_size)
+            != success) {
+            NN_ERR_PRINTF("Error when getting output.");
+            exit(1);
+        }
+
+        offset += *output_size;
+    }
+    *output_size = offset;
+    return out_tensor;
+}
+
+input_info
+create_input(int *dims)
+{
+    input_info input = { .dim = NULL, .input_tensor = NULL, .elements = 1 };
+
+    input.dim = malloc(INPUT_TENSOR_DIMS * sizeof(uint32_t));
+    if (input.dim)
+        for (int i = 0; i < INPUT_TENSOR_DIMS; ++i) {
+            input.dim[i] = dims[i];
+            input.elements *= dims[i];
+        }
+
+    input.input_tensor = malloc(input.elements * sizeof(float));
+    for (int i = 0; i < input.elements; ++i)
+        input.input_tensor[i] = i;
+
+    return input;
+}

+ 52 - 0
core/iwasm/libraries/wasi-nn/test/utils.h

@@ -0,0 +1,52 @@
+/*
+ * Copyright (C) 2019 Intel Corporation.  All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+
+#ifndef WASI_NN_UTILS
+#define WASI_NN_UTILS
+
+#include <stdint.h>
+
+#include "wasi_nn.h"
+
+#define MAX_MODEL_SIZE 85000000
+#define MAX_OUTPUT_TENSOR_SIZE 200
+#define INPUT_TENSOR_DIMS 4
+#define EPSILON 1e-8
+
+typedef struct {
+    float *input_tensor;
+    uint32_t *dim;
+    uint32_t elements;
+} input_info;
+
+/* wasi-nn wrappers */
+
+error
+wasm_load(char *model_name, graph *g, execution_target target);
+
+error
+wasm_init_execution_context(graph g, graph_execution_context *ctx);
+
+error
+wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim);
+
+error
+wasm_compute(graph_execution_context ctx);
+
+error
+wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor,
+                uint32_t *out_size);
+
+/* Utils */
+
+float *
+run_inference(execution_target target, float *input, uint32_t *input_size,
+              uint32_t *output_size, char *model_name,
+              uint32_t num_output_tensors);
+
+input_info
+create_input(int *dims);
+
+#endif

+ 0 - 2
core/iwasm/libraries/wasi-nn/wasi_nn.cmake

@@ -3,8 +3,6 @@
 
 set (WASI_NN_DIR ${CMAKE_CURRENT_LIST_DIR})
 
-add_definitions (-DWASM_ENABLE_WASI_NN=1)
-
 include_directories (${WASI_NN_DIR})
 include_directories (${WASI_NN_DIR}/src)
 include_directories (${WASI_NN_DIR}/src/utils)