Просмотр исходного кода

wasi-nn: fix tensor_data abi for wasi_ephemeral_nn (#4379)

it's "(list u8)" in the witx definition.

the new definition matches both of our own host definition
(struct tensor_wasm) and wasmtime.

cf. https://github.com/bytecodealliance/wasm-micro-runtime/issues/4352
YAMAMOTO Takashi 8 месяцев назад
Родитель
Сommit
aa53d648fa

+ 3 - 4
core/iwasm/libraries/wasi-nn/include/wasi_nn.h

@@ -108,14 +108,13 @@ WASI_NN_NAME(compute)
 WASI_NN_ERROR_TYPE
 WASI_NN_NAME(get_output)
 (WASI_NN_NAME(graph_execution_context) ctx, uint32_t index,
- WASI_NN_NAME(tensor_data) output_tensor, uint32_t output_tensor_max_size,
+ uint8_t *output_tensor, uint32_t output_tensor_max_size,
  uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output");
 #else
 WASI_NN_ERROR_TYPE
 WASI_NN_NAME(get_output)
-(graph_execution_context ctx, uint32_t index,
- WASI_NN_NAME(tensor_data) output_tensor, uint32_t *output_tensor_size)
-    WASI_NN_IMPORT("get_output");
+(graph_execution_context ctx, uint32_t index, uint8_t *output_tensor,
+ uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output");
 #endif
 
 #endif

+ 7 - 0
core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h

@@ -99,7 +99,14 @@ typedef enum {
 // 4-byte f32 elements would have a data array of length 16). Naturally, this
 // representation requires some knowledge of how to lay out data in
 // memory--e.g., using row-major ordering--and could perhaps be improved.
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 && defined(__wasm__)
+typedef struct {
+    uint8_t *buf;
+    uint32_t size;
+} WASI_NN_NAME(tensor_data);
+#else
 typedef uint8_t *WASI_NN_NAME(tensor_data);
+#endif
 
 // A tensor.
 typedef struct {

+ 2 - 1
wamr-wasi-extensions/samples/nn-cli/main.c

@@ -266,7 +266,8 @@ set_input(char *options)
     wasi_ephemeral_nn_error nnret;
     wasi_ephemeral_nn_graph_execution_context c =
         map_get(&contexts, context_id);
-    tensor.data = buf;
+    tensor.data.buf = buf;
+    tensor.data.size = sz;
     nnret = wasi_ephemeral_nn_set_input(c, idx, &tensor);
     unmap_file(buf, sz);
     if (nnret != wasi_ephemeral_nn_error_success) {

+ 2 - 1
wamr-wasi-extensions/samples/nn/app.c

@@ -147,7 +147,8 @@ main(int argc, char **argv)
     wasi_ephemeral_nn_tensor tensor = {
         .dimensions = { .buf = (uint32_t[]){1, 3, 224, 224,}, .size = 4, },
         .type = wasi_ephemeral_nn_type_fp32,
-        .data = tensordata,
+        .data.buf = tensordata,
+        .data.size = tensordatasz,
     };
     nnret = wasi_ephemeral_nn_set_input(ctx, 0, &tensor);
     unmap_file(tensordata, tensordatasz);