Эх сурвалжийг харах

Add `wasi_ephemeral_nn` module support (#3241)

Add `wasi_ephemeral_nn` module support, enabled by the optional CMake variable
`WAMR_BUILD_WASI_EPHEMERAL_NN`, as mentioned in #3229.
Xu Jinyang 1 жил өмнө
parent
commit
cef88deedb

+ 4 - 0
build-scripts/config_common.cmake

@@ -430,6 +430,10 @@ if (WAMR_BUILD_WASI_NN EQUAL 1)
   if (DEFINED WAMR_BUILD_WASI_NN_EXTERNAL_DELEGATE_PATH)
       add_definitions (-DWASM_WASI_NN_EXTERNAL_DELEGATE_PATH="${WAMR_BUILD_WASI_NN_EXTERNAL_DELEGATE_PATH}")
   endif ()
+  if (WAMR_BUILD_WASI_EPHEMERAL_NN EQUAL 1)
+      message ("     WASI-NN: WASI-Ephemeral-NN enabled")
+      add_definitions (-DWASM_ENABLE_WASI_EPHEMERAL_NN=1)
+  endif()
 endif ()
 if (WAMR_BUILD_ALLOC_WITH_USER_DATA EQUAL 1)
   add_definitions(-DWASM_MEM_ALLOC_WITH_USER_DATA=1)

+ 4 - 0
core/config.h

@@ -152,6 +152,10 @@
 #define WASM_ENABLE_WASI_NN_EXTERNAL_DELEGATE 0
 #endif
 
+#ifndef WASM_ENABLE_WASI_EPHEMERAL_NN
+#define WASM_ENABLE_WASI_EPHEMERAL_NN 0
+#endif
+
 /* Default disable libc emcc */
 #ifndef WASM_ENABLE_LIBC_EMCC
 #define WASM_ENABLE_LIBC_EMCC 0

+ 6 - 1
core/iwasm/common/wasm_native.c

@@ -567,7 +567,12 @@ wasm_native_init()
 
 #if WASM_ENABLE_WASI_NN != 0
     n_native_symbols = get_wasi_nn_export_apis(&native_symbols);
-    if (!wasm_native_register_natives("wasi_nn", native_symbols,
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+#define wasi_nn_module_name "wasi_ephemeral_nn"
+#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
+#define wasi_nn_module_name "wasi_nn"
+#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
+    if (!wasm_native_register_natives(wasi_nn_module_name, native_symbols,
                                       n_native_symbols))
         goto fail;
 #endif

+ 43 - 7
core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.c

@@ -23,24 +23,47 @@ graph_builder_app_native(wasm_module_inst_t instance,
     return success;
 }
 
+/**
+ * builder_array_wasm consists of {builder_wasm, size}
+ */
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+error
+graph_builder_array_app_native(wasm_module_inst_t instance,
+                               graph_builder_wasm *builder_wasm, uint32_t size,
+                               graph_builder_array *builder_array)
+#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
 error
 graph_builder_array_app_native(wasm_module_inst_t instance,
                                graph_builder_array_wasm *builder_array_wasm,
                                graph_builder_array *builder_array)
+#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 {
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+#define array_size size
+#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
+#define array_size builder_array_wasm->size
+
     if (!wasm_runtime_validate_native_addr(
             instance, builder_array_wasm,
             (uint64)sizeof(graph_builder_array_wasm))) {
         NN_ERR_PRINTF("builder_array_wasm is invalid");
         return invalid_argument;
     }
+#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 
-    NN_DBG_PRINTF("Graph builder array contains %d elements",
-                  builder_array_wasm->size);
+    NN_DBG_PRINTF("Graph builder array contains %d elements", array_size);
 
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+    if (!wasm_runtime_validate_native_addr(instance, builder_wasm,
+                                           (uint64)array_size
+                                               * sizeof(graph_builder_wasm))) {
+        NN_ERR_PRINTF("builder_wasm is invalid");
+        return invalid_argument;
+    }
+#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
     if (!wasm_runtime_validate_app_addr(
             instance, (uint64)builder_array_wasm->buf_offset,
-            (uint64)builder_array_wasm->size * sizeof(graph_builder_wasm))) {
+            (uint64)array_size * sizeof(graph_builder_wasm))) {
         NN_ERR_PRINTF("builder_array_wasm->buf_offset is invalid");
         return invalid_argument;
     }
@@ -48,13 +71,14 @@ graph_builder_array_app_native(wasm_module_inst_t instance,
     graph_builder_wasm *builder_wasm =
         (graph_builder_wasm *)wasm_runtime_addr_app_to_native(
             instance, (uint64)builder_array_wasm->buf_offset);
+#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 
     graph_builder *builder = (graph_builder *)wasm_runtime_malloc(
-        builder_array_wasm->size * sizeof(graph_builder));
+        array_size * sizeof(graph_builder));
     if (builder == NULL)
         return missing_memory;
 
-    for (uint32_t i = 0; i < builder_array_wasm->size; ++i) {
+    for (uint32_t i = 0; i < array_size; ++i) {
         error res;
         if (success
             != (res = graph_builder_app_native(instance, &builder_wasm[i],
@@ -68,23 +92,31 @@ graph_builder_array_app_native(wasm_module_inst_t instance,
     }
 
     builder_array->buf = builder;
-    builder_array->size = builder_array_wasm->size;
+    builder_array->size = array_size;
     return success;
+#undef array_size
 }
 
 static error
 tensor_data_app_native(wasm_module_inst_t instance, uint32_t total_elements,
                        tensor_wasm *input_tensor_wasm, tensor_data *data)
 {
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+#define data_size input_tensor_wasm->data_size
+#else
+#define data_size total_elements
+#endif
+
     if (!wasm_runtime_validate_app_addr(instance,
                                         (uint64)input_tensor_wasm->data_offset,
-                                        (uint64)total_elements)) {
+                                        (uint64)data_size)) {
         NN_ERR_PRINTF("input_tensor_wasm->data_offset is invalid");
         return invalid_argument;
     }
     *data = (tensor_data)wasm_runtime_addr_app_to_native(
         instance, (uint64)input_tensor_wasm->data_offset);
     return success;
+#undef data_size
 }
 
 static error
@@ -92,6 +124,9 @@ tensor_dimensions_app_native(wasm_module_inst_t instance,
                              tensor_wasm *input_tensor_wasm,
                              tensor_dimensions **dimensions)
 {
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+    tensor_dimensions_wasm *dimensions_wasm = &input_tensor_wasm->dimensions;
+#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
     if (!wasm_runtime_validate_app_addr(
             instance, (uint64)input_tensor_wasm->dimensions_offset,
             (uint64)sizeof(tensor_dimensions_wasm))) {
@@ -102,6 +137,7 @@ tensor_dimensions_app_native(wasm_module_inst_t instance,
     tensor_dimensions_wasm *dimensions_wasm =
         (tensor_dimensions_wasm *)wasm_runtime_addr_app_to_native(
             instance, (uint64)input_tensor_wasm->dimensions_offset);
+#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 
     if (!wasm_runtime_validate_app_addr(instance,
                                         (uint64)dimensions_wasm->buf_offset,

+ 14 - 0
core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.h

@@ -34,15 +34,29 @@ typedef struct {
 } tensor_dimensions_wasm;
 
 typedef struct {
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+    tensor_dimensions_wasm dimensions;
+    tensor_type type;
+    uint32_t data_offset;
+    uint32_t data_size;
+#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
     uint32_t dimensions_offset;
     tensor_type type;
     uint32_t data_offset;
+#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 } tensor_wasm;
 
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+error
+graph_builder_array_app_native(wasm_module_inst_t instance,
+                               graph_builder_wasm *builder_wasm, uint32_t size,
+                               graph_builder_array *builder_array);
+#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
 error
 graph_builder_array_app_native(wasm_module_inst_t instance,
                                graph_builder_array_wasm *builder,
                                graph_builder_array *builder_native);
+#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 
 error
 tensor_app_native(wasm_module_inst_t instance, tensor_wasm *input_tensor,

+ 35 - 0
core/iwasm/libraries/wasi-nn/src/wasi_nn.c

@@ -189,9 +189,16 @@ is_model_initialized(WASINNContext *wasi_nn_ctx)
 
 /* WASI-NN implementation */
 
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+error
+wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_wasm *builder,
+             uint32_t builder_wasm_size, graph_encoding encoding,
+             execution_target target, graph *g)
+#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
 error
 wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
              graph_encoding encoding, execution_target target, graph *g)
+#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 {
     NN_DBG_PRINTF("Running wasi_nn_load [encoding=%d, target=%d]...", encoding,
                   target);
@@ -206,10 +213,17 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
 
     error res;
     graph_builder_array builder_native = { 0 };
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+    if (success
+        != (res = graph_builder_array_app_native(
+                instance, builder, builder_wasm_size, &builder_native)))
+        return res;
+#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
     if (success
         != (res = graph_builder_array_app_native(instance, builder,
                                                  &builder_native)))
         return res;
+#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 
     if (!wasm_runtime_validate_native_addr(instance, g,
                                            (uint64)sizeof(graph))) {
@@ -315,10 +329,17 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
     return res;
 }
 
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+error
+wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
+                   uint32_t index, tensor_data output_tensor,
+                   uint32_t output_tensor_len, uint32_t *output_tensor_size)
+#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
 error
 wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
                    uint32_t index, tensor_data output_tensor,
                    uint32_t *output_tensor_size)
+#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 {
     NN_DBG_PRINTF("Running wasi_nn_get_output [ctx=%d, index=%d]...", ctx,
                   index);
@@ -337,8 +358,14 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
         return invalid_argument;
     }
 
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+    res = lookup[wasi_nn_ctx->current_encoding].get_output(
+        wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, &output_tensor_len);
+    *output_tensor_size = output_tensor_len;
+#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
     res = lookup[wasi_nn_ctx->current_encoding].get_output(
         wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, output_tensor_size);
+#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
     NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]",
                   res, *output_tensor_size);
     return res;
@@ -352,11 +379,19 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
 /* clang-format on */
 
 static NativeSymbol native_symbols_wasi_nn[] = {
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+    REG_NATIVE_FUNC(load, "(*iii*)i"),
+    REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
+    REG_NATIVE_FUNC(set_input, "(ii*)i"),
+    REG_NATIVE_FUNC(compute, "(i)i"),
+    REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
+#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
     REG_NATIVE_FUNC(load, "(*ii*)i"),
     REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
     REG_NATIVE_FUNC(set_input, "(ii*)i"),
     REG_NATIVE_FUNC(compute, "(i)i"),
     REG_NATIVE_FUNC(get_output, "(ii**)i"),
+#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 };
 
 uint32_t

+ 3 - 0
doc/build_wamr.md

@@ -107,6 +107,9 @@ cmake -DWAMR_BUILD_PLATFORM=linux -DWAMR_BUILD_TARGET=ARM
 
 - **WAMR_BUILD_WASI_NN_EXTERNAL_DELEGATE_PATH**=Path to the external delegate shared library (e.g. `libedgetpu.so.1.0` for Coral USB)
 
+#### **Enable lib wasi-nn with `wasi_ephemeral_nn` module support**
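+- **WAMR_BUILD_WASI_EPHEMERAL_NN**=1/0, default to disable if not set; it only takes effect when **WAMR_BUILD_WASI_NN**=1, and registers the wasi-nn native APIs under the `wasi_ephemeral_nn` module name instead of `wasi_nn`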
+
 #### **Disable boundary check with hardware trap**
 - **WAMR_DISABLE_HW_BOUND_CHECK**=1/0, default to enable if not set and supported by platform
 > Note: by default only platform [linux/darwin/android/windows/vxworks 64-bit](https://github.com/bytecodealliance/wasm-micro-runtime/blob/5fb5119239220b0803e7045ca49b0a29fe65e70e/core/shared/platform/linux/platform_internal.h#L81) will enable the boundary check with hardware trap feature, for 32-bit platforms it's automatically disabled even when the flag is set to 0, and the wamrc tool will generate AOT code without boundary check instructions in all 64-bit targets except SGX to improve performance. The boundary check includes linear memory access boundary and native stack access boundary, if `WAMR_DISABLE_STACK_HW_BOUND_CHECK` below isn't set.