Bläddra i källkod

add load_by_name in wasi-nn (#4298)

hongxia 7 månader sedan
förälder
incheckning
aa1ff778b9

+ 1 - 1
core/iwasm/libraries/wasi-nn/include/wasi_nn.h

@@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding,
     __attribute__((import_module("wasi_nn")));
 
 wasi_nn_error
-load_by_name(const char *name, graph *g)
+load_by_name(const char *name, uint32_t name_len, graph *g)
     __attribute__((import_module("wasi_nn")));
 
 /**

+ 1 - 0
core/iwasm/libraries/wasi-nn/src/wasi_nn.c

@@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = {
     REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
 #else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
     REG_NATIVE_FUNC(load, "(*ii*)i"),
+    REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
     REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
     REG_NATIVE_FUNC(set_input, "(ii*)i"),
     REG_NATIVE_FUNC(compute, "(i)i"),

+ 18 - 23
core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp

@@ -85,12 +85,8 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g)
         NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
         return runtime_error;
     }
-    if (tfl_ctx->models[g].model_pointer == NULL) {
-        NN_ERR_PRINTF("Context (model) non-initialized.");
-        return runtime_error;
-    }
     if (tfl_ctx->models[g].model == NULL) {
-        NN_ERR_PRINTF("Context (tflite model) non-initialized.");
+        NN_ERR_PRINTF("Context (model) non-initialized.");
         return runtime_error;
     }
     return success;
@@ -472,32 +468,31 @@ deinit_backend(void *tflite_ctx)
     NN_DBG_PRINTF("Freeing memory.");
     for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) {
         tfl_ctx->models[i].model.reset();
-        if (tfl_ctx->models[i].model_pointer) {
-            if (tfl_ctx->delegate) {
-                switch (tfl_ctx->models[i].target) {
-                    case gpu:
-                    {
+        if (tfl_ctx->delegate) {
+            switch (tfl_ctx->models[i].target) {
+                case gpu:
+                {
 #if WASM_ENABLE_WASI_NN_GPU != 0
-                        TfLiteGpuDelegateV2Delete(tfl_ctx->delegate);
+                    TfLiteGpuDelegateV2Delete(tfl_ctx->delegate);
 #else
-                        NN_ERR_PRINTF("GPU delegate delete but not enabled.");
+                    NN_ERR_PRINTF("GPU delegate delete but not enabled.");
 #endif
-                        break;
-                    }
-                    case tpu:
-                    {
+                    break;
+                }
+                case tpu:
+                {
 #if WASM_ENABLE_WASI_NN_EXTERNAL_DELEGATE != 0
-                        TfLiteExternalDelegateDelete(tfl_ctx->delegate);
+                    TfLiteExternalDelegateDelete(tfl_ctx->delegate);
 #else
-                        NN_ERR_PRINTF(
-                            "External delegate delete but not enabled.");
+                    NN_ERR_PRINTF("External delegate delete but not enabled.");
 #endif
-                        break;
-                    }
-                    default:
-                        break;
+                    break;
                 }
+                default:
+                    break;
             }
+        }
+        if (tfl_ctx->models[i].model_pointer) {
             wasm_runtime_free(tfl_ctx->models[i].model_pointer);
         }
         tfl_ctx->models[i].model_pointer = NULL;

+ 3 - 2
core/iwasm/libraries/wasi-nn/test/utils.c

@@ -58,7 +58,7 @@ wasm_load(char *model_name, graph *g, execution_target target)
 wasi_nn_error
 wasm_load_by_name(const char *model_name, graph *g)
 {
-    wasi_nn_error res = load_by_name(model_name, g);
+    wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
     return res;
 }
 
@@ -108,7 +108,8 @@ run_inference(execution_target target, float *input, uint32_t *input_size,
               uint32_t num_output_tensors)
 {
     graph graph;
-    if (wasm_load(model_name, &graph, target) != success) {
+
+    if (wasm_load_by_name(model_name, &graph) != success) {
         NN_ERR_PRINTF("Error when loading model.");
         exit(1);
     }