@@ -16,25 +16,105 @@
 #include <tensorflow/lite/model.h>
 #include <tensorflow/lite/optional_debug_tools.h>
 #include <tensorflow/lite/error_reporter.h>
-#include <tensorflow/lite/delegates/gpu/delegate.h>

-/* Global variables */
+#if defined(WASI_NN_ENABLE_GPU)
+#include <tensorflow/lite/delegates/gpu/delegate.h>
+#endif
+
+/* Maximum number of graphs per WASM instance */
+#define MAX_GRAPHS_PER_INST 10
+/* Maximum number of graph execution contexts per WASM instance */
+#define MAX_GRAPH_EXEC_CONTEXTS_PER_INST 10
+
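+/* Holds the TFLite interpreter built for one graph execution context. */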
+typedef struct {
+    std::unique_ptr<tflite::Interpreter> interpreter;
+} Interpreter;
+
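+/* Holds one loaded graph: the copied model buffer, the parsed FlatBuffer
+   model and the execution target requested at load time. */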
+typedef struct {
+    char *model_pointer;
+    std::unique_ptr<tflite::FlatBufferModel> model;
+    execution_target target;
+} Model;
+
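+/* Per-WASM-instance state: the loaded models, their interpreters and a lock
+   guarding slot allocation. Slots are only reclaimed when the instance is
+   destroyed. */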
+typedef struct {
+    uint32_t current_models;
+    Model models[MAX_GRAPHS_PER_INST];
+    uint32_t current_interpreters;
+    Interpreter interpreters[MAX_GRAPH_EXEC_CONTEXTS_PER_INST];
+    korp_mutex g_lock;
+} TFLiteContext;
+
+/* Utils */
+
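+/* Reserve the next free model slot under the lock; fails once
+   MAX_GRAPHS_PER_INST graphs have been loaded by this instance. */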
+static error
+initialize_g(TFLiteContext *tfl_ctx, graph *g)
+{
+    os_mutex_lock(&tfl_ctx->g_lock);
+    if (tfl_ctx->current_models == MAX_GRAPHS_PER_INST) {
+        os_mutex_unlock(&tfl_ctx->g_lock);
+        NN_ERR_PRINTF("Exceeded max graphs per WASM instance");
+        return runtime_error;
+    }
+    *g = tfl_ctx->current_models++;
+    os_mutex_unlock(&tfl_ctx->g_lock);
+    return success;
+}
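+
+/* Reserve the next free interpreter slot under the lock; fails once
+   MAX_GRAPH_EXEC_CONTEXTS_PER_INST execution contexts exist. */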
+static error
+initialize_graph_ctx(TFLiteContext *tfl_ctx, graph g,
+                     graph_execution_context *ctx)
+{
+    os_mutex_lock(&tfl_ctx->g_lock);
+    if (tfl_ctx->current_interpreters == MAX_GRAPH_EXEC_CONTEXTS_PER_INST) {
+        os_mutex_unlock(&tfl_ctx->g_lock);
+        NN_ERR_PRINTF("Exceeded max graph execution contexts per WASM instance");
+        return runtime_error;
+    }
+    *ctx = tfl_ctx->current_interpreters++;
+    os_mutex_unlock(&tfl_ctx->g_lock);
+    return success;
+}

-static std::unique_ptr<tflite::Interpreter> interpreter;
-static std::unique_ptr<tflite::FlatBufferModel> model;
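+/* Check that a graph handle is in range and refers to a loaded model. */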
+static error
+is_valid_graph(TFLiteContext *tfl_ctx, graph g)
+{
+    if (g >= MAX_GRAPHS_PER_INST) {
+        NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
+        return runtime_error;
+    }
+    if (tfl_ctx->models[g].model_pointer == NULL) {
+        NN_ERR_PRINTF("Context (model) not initialized.");
+        return runtime_error;
+    }
+    if (tfl_ctx->models[g].model == NULL) {
+        NN_ERR_PRINTF("Context (tflite model) not initialized.");
+        return runtime_error;
+    }
+    return success;
+}

-static char *model_pointer = NULL;
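+/* Check that an execution context handle is in range and has an interpreter. */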
+static error
+is_valid_graph_execution_context(TFLiteContext *tfl_ctx,
+                                 graph_execution_context ctx)
+{
+    if (ctx >= MAX_GRAPH_EXEC_CONTEXTS_PER_INST) {
+        NN_ERR_PRINTF("Invalid graph execution context: %d >= %d", ctx,
+                      MAX_GRAPH_EXEC_CONTEXTS_PER_INST);
+        return runtime_error;
+    }
+    if (tfl_ctx->interpreters[ctx].interpreter == NULL) {
+        NN_ERR_PRINTF("Context (interpreter) not initialized.");
+        return runtime_error;
+    }
+    return success;
+}

 /* WASI-NN (tensorflow) implementation */

 error
-tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
-                    execution_target target, graph *g)
+tensorflowlite_load(void *tflite_ctx, graph_builder_array *builder,
+                    graph_encoding encoding, execution_target target, graph *g)
 {
-    if (model_pointer != NULL) {
-        wasm_runtime_free(model_pointer);
-        model_pointer = NULL;
-    }
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;

     if (builder->size != 1) {
         NN_ERR_PRINTF("Unexpected builder format.");
@@ -51,39 +131,68 @@ tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
         return invalid_argument;
     }

+    error res;
+    if (success != (res = initialize_g(tfl_ctx, g)))
+        return res;
+
     uint32_t size = builder->buf[0].size;

-    model_pointer = (char *)wasm_runtime_malloc(size);
-    if (model_pointer == NULL) {
+    // Save model
+    tfl_ctx->models[*g].model_pointer = (char *)wasm_runtime_malloc(size);
+    if (tfl_ctx->models[*g].model_pointer == NULL) {
         NN_ERR_PRINTF("Error when allocating memory for model.");
         return missing_memory;
     }

-    bh_memcpy_s(model_pointer, size, builder->buf[0].buf, size);
+    bh_memcpy_s(tfl_ctx->models[*g].model_pointer, size, builder->buf[0].buf,
+                size);

-    model = tflite::FlatBufferModel::BuildFromBuffer(model_pointer, size, NULL);
-    if (model == NULL) {
+    // Save model flatbuffer
+    tfl_ctx->models[*g].model =
+        std::move(tflite::FlatBufferModel::BuildFromBuffer(
+            tfl_ctx->models[*g].model_pointer, size, NULL));
+
+    if (tfl_ctx->models[*g].model == NULL) {
         NN_ERR_PRINTF("Loading model error.");
-        wasm_runtime_free(model_pointer);
-        model_pointer = NULL;
+        wasm_runtime_free(tfl_ctx->models[*g].model_pointer);
+        tfl_ctx->models[*g].model_pointer = NULL;
         return missing_memory;
     }

+    // Save target
+    tfl_ctx->models[*g].target = target;
+    return success;
+}
+
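+/* Build a TFLite interpreter for graph g and, when requested and compiled in,
+   attach the GPU delegate before allocating tensors. */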
+error
+tensorflowlite_init_execution_context(void *tflite_ctx, graph g,
+                                      graph_execution_context *ctx)
+{
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
+
+    error res;
+    if (success != (res = is_valid_graph(tfl_ctx, g)))
+        return res;
+
+    if (success != (res = initialize_graph_ctx(tfl_ctx, g, ctx)))
+        return res;
+
     // Build the interpreter with the InterpreterBuilder.
     tflite::ops::builtin::BuiltinOpResolver resolver;
-    tflite::InterpreterBuilder tflite_builder(*model, resolver);
-    tflite_builder(&interpreter);
-    if (interpreter == NULL) {
+    tflite::InterpreterBuilder tflite_builder(*tfl_ctx->models[g].model,
+                                              resolver);
+    tflite_builder(&tfl_ctx->interpreters[*ctx].interpreter);
+    if (tfl_ctx->interpreters[*ctx].interpreter == NULL) {
         NN_ERR_PRINTF("Error when generating the interpreter.");
-        wasm_runtime_free(model_pointer);
-        model_pointer = NULL;
         return missing_memory;
     }

     bool use_default = false;
-    switch (target) {
+    switch (tfl_ctx->models[g].target) {
         case gpu:
         {
+#if defined(WASI_NN_ENABLE_GPU)
+            NN_WARN_PRINTF("GPU enabled.");
             // https://www.tensorflow.org/lite/performance/gpu
             auto options = TfLiteGpuDelegateOptionsV2Default();
             options.inference_preference =
@@ -91,10 +200,16 @@ tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
             options.inference_priority1 =
                 TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
             auto *delegate = TfLiteGpuDelegateV2Create(&options);
-            if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
+            if (tfl_ctx->interpreters[*ctx]
+                    .interpreter->ModifyGraphWithDelegate(delegate)
+                != kTfLiteOk) {
                 NN_ERR_PRINTF("Error when enabling GPU delegate.");
                 use_default = true;
             }
+#else
+            NN_WARN_PRINTF("GPU not enabled.");
+            use_default = true;
+#endif
             break;
         }
         default:
@@ -103,36 +218,28 @@ tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
     if (use_default)
         NN_WARN_PRINTF("Default encoding is CPU.");

+    tfl_ctx->interpreters[*ctx].interpreter->AllocateTensors();
     return success;
 }

 error
-tensorflowlite_init_execution_context(graph g, graph_execution_context *ctx)
+tensorflowlite_set_input(void *tflite_ctx, graph_execution_context ctx,
+                         uint32_t index, tensor *input_tensor)
 {
-    if (interpreter == NULL) {
-        NN_ERR_PRINTF("Non-initialized interpreter.");
-        return runtime_error;
-    }
-    interpreter->AllocateTensors();
-    return success;
-}
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;

-error
-tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
-                         tensor *input_tensor)
-{
-    if (interpreter == NULL) {
-        NN_ERR_PRINTF("Non-initialized interpreter.");
-        return runtime_error;
-    }
+    error res;
+    if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx)))
+        return res;

-    uint32_t num_tensors = interpreter->inputs().size();
+    uint32_t num_tensors =
+        tfl_ctx->interpreters[ctx].interpreter->inputs().size();
     NN_DBG_PRINTF("Number of tensors (%d)", num_tensors);
     if (index + 1 > num_tensors) {
         return runtime_error;
     }

-    auto tensor = interpreter->input_tensor(index);
+    auto tensor = tfl_ctx->interpreters[ctx].interpreter->input_tensor(index);
     if (tensor == NULL) {
         NN_ERR_PRINTF("Missing memory");
         return missing_memory;
@@ -152,7 +259,9 @@ tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
         return invalid_argument;
     }

-    auto *input = interpreter->typed_input_tensor<float>(index);
+    auto *input =
+        tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<float>(
+            index);
     if (input == NULL)
         return missing_memory;

@@ -162,34 +271,38 @@ tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
 }

 error
-tensorflowlite_compute(graph_execution_context ctx)
+tensorflowlite_compute(void *tflite_ctx, graph_execution_context ctx)
 {
-    if (interpreter == NULL) {
-        NN_ERR_PRINTF("Non-initialized interpreter.");
-        return runtime_error;
-    }
-    interpreter->Invoke();
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
+
+    error res;
+    if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx)))
+        return res;
+
+    tfl_ctx->interpreters[ctx].interpreter->Invoke();
     return success;
 }

 error
-tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
-                          tensor_data output_tensor,
+tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
+                          uint32_t index, tensor_data output_tensor,
                           uint32_t *output_tensor_size)
 {
-    if (interpreter == NULL) {
-        NN_ERR_PRINTF("Non-initialized interpreter.");
-        return runtime_error;
-    }
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;

-    uint32_t num_output_tensors = interpreter->outputs().size();
+    error res;
+    if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx)))
+        return res;
+
+    uint32_t num_output_tensors =
+        tfl_ctx->interpreters[ctx].interpreter->outputs().size();
     NN_DBG_PRINTF("Number of tensors (%d)", num_output_tensors);

     if (index + 1 > num_output_tensors) {
         return runtime_error;
     }

-    auto tensor = interpreter->output_tensor(index);
+    auto tensor = tfl_ctx->interpreters[ctx].interpreter->output_tensor(index);
     if (tensor == NULL) {
         NN_ERR_PRINTF("Missing memory");
         return missing_memory;
@@ -204,7 +317,9 @@ tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
         return missing_memory;
     }

-    float *tensor_f = interpreter->typed_output_tensor<float>(index);
+    float *tensor_f =
+        tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
+            index);
     for (uint32_t i = 0; i < model_tensor_size; ++i)
         NN_DBG_PRINTF("output: %f", tensor_f[i]);

@@ -215,20 +330,51 @@ tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
 }

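+/*
+ * Expected per-instance call sequence, sketched from the functions below
+ * (the generic wasi-nn layer is assumed to drive it; `builder`, `encoding`,
+ * `target` and the tensors come from the caller):
+ *
+ *   void *ctx;
+ *   graph g;
+ *   graph_execution_context exec_ctx;
+ *   tensorflowlite_initialize(&ctx);
+ *   tensorflowlite_load(ctx, &builder, encoding, target, &g);
+ *   tensorflowlite_init_execution_context(ctx, g, &exec_ctx);
+ *   tensorflowlite_set_input(ctx, exec_ctx, 0, &input_tensor);
+ *   tensorflowlite_compute(ctx, exec_ctx);
+ *   tensorflowlite_get_output(ctx, exec_ctx, 0, out_buf, &out_size);
+ *   tensorflowlite_destroy(ctx);
+ */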
 void
-tensorflowlite_destroy()
+tensorflowlite_initialize(void **tflite_ctx)
+{
+    TFLiteContext *tfl_ctx = new TFLiteContext();
+    if (tfl_ctx == NULL) {
+        NN_ERR_PRINTF("Error when allocating memory for tensorflowlite.");
+        return;
+    }
+
+    NN_DBG_PRINTF("Initializing models.");
+    tfl_ctx->current_models = 0;
+    for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) {
+        tfl_ctx->models[i].model_pointer = NULL;
+    }
+    NN_DBG_PRINTF("Initializing interpreters.");
+    tfl_ctx->current_interpreters = 0;
+
+    if (os_mutex_init(&tfl_ctx->g_lock) != 0) {
+        NN_ERR_PRINTF("Error while initializing the lock");
+    }
+
+    *tflite_ctx = (void *)tfl_ctx;
+}
+
+void
+tensorflowlite_destroy(void *tflite_ctx)
 {
     /*
-       TensorFlow Lite memory is man
+       TensorFlow Lite memory is internally managed by TensorFlow

        Related issues:
        * https://github.com/tensorflow/tensorflow/issues/15880
     */
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
+
     NN_DBG_PRINTF("Freeing memory.");
-    model.reset(nullptr);
-    model = NULL;
-    interpreter.reset(nullptr);
-    interpreter = NULL;
-    wasm_runtime_free(model_pointer);
-    model_pointer = NULL;
+    for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) {
+        tfl_ctx->models[i].model.reset();
+        if (tfl_ctx->models[i].model_pointer)
+            wasm_runtime_free(tfl_ctx->models[i].model_pointer);
+        tfl_ctx->models[i].model_pointer = NULL;
+    }
+    for (int i = 0; i < MAX_GRAPH_EXEC_CONTEXTS_PER_INST; ++i) {
+        tfl_ctx->interpreters[i].interpreter.reset();
+    }
+    os_mutex_destroy(&tfl_ctx->g_lock);
+    delete tfl_ctx;
     NN_DBG_PRINTF("Memory free'd.");
 }