Просмотр исходного кода

wasi_nn.h: make this compatible with wasi_ephemeral_nn (#4330)

- wasi_nn.h: make this compatible with wasi_ephemeral_nn
cf. https://github.com/bytecodealliance/wasm-micro-runtime/issues/4323

- fix WASM_ENABLE_WASI_EPHEMERAL_NN build
this structure is used by host logic as well.
ideally definitions for wasm and host should be separated.
until it happens, check __wasm__ to avoid the breakage.
YAMAMOTO Takashi 7 месяцев назад
Родитель
Сommit
4d6b8dcd5d

+ 19 - 0
core/iwasm/libraries/wasi-nn/include/wasi_nn.h

@@ -15,21 +15,33 @@
 #include <stdint.h>
 #include "wasi_nn_types.h"
 
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+#define WASI_NN_IMPORT(name) \
+    __attribute__((import_module("wasi_ephemeral_nn"), import_name(name)))
+#else
 #define WASI_NN_IMPORT(name) \
     __attribute__((import_module("wasi_nn"), import_name(name)))
+#endif
 
 /**
  * @brief Load an opaque sequence of bytes to use for inference.
  *
  * @param builder   Model builder.
+ * @param builder_len The size of model builder.
  * @param encoding  Model encoding.
  * @param target    Execution target.
  * @param g         Graph.
  * @return wasi_nn_error    Execution status.
  */
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+wasi_nn_error
+load(graph_builder *builder, uint32_t builder_len, graph_encoding encoding,
+     execution_target target, graph *g) WASI_NN_IMPORT("load");
+#else
 wasi_nn_error
 load(graph_builder_array *builder, graph_encoding encoding,
      execution_target target, graph *g) WASI_NN_IMPORT("load");
+#endif
 
 wasi_nn_error
 load_by_name(const char *name, uint32_t name_len, graph *g)
@@ -84,9 +96,16 @@ compute(graph_execution_context ctx) WASI_NN_IMPORT("compute");
  * copied number of bytes.
  * @return wasi_nn_error                Execution status.
  */
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+wasi_nn_error
+get_output(graph_execution_context ctx, uint32_t index,
+           tensor_data output_tensor, uint32_t output_tensor_max_size,
+           uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output");
+#else
 wasi_nn_error
 get_output(graph_execution_context ctx, uint32_t index,
            tensor_data output_tensor, uint32_t *output_tensor_size)
     WASI_NN_IMPORT("get_output");
+#endif
 
 #endif

+ 4 - 0
core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h

@@ -77,7 +77,11 @@ typedef struct {
     // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To
     // represent a tensor containing a single value, use `[1]` for the tensor
     // dimensions.
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 && defined(__wasm__)
+    tensor_dimensions dimensions;
+#else
     tensor_dimensions *dimensions;
+#endif
     // Describe the type of element in the tensor (e.g., f32).
     uint8_t type;
     uint8_t _pad[3];