
wasi_nn_tensorflowlite.cpp: fix get_output return size (#4390)

It should be the byte size, not the number of (fp32) values.

I'm ambivalent about how to deal with compatibility for the
legacy WAMR-specific "wasi_nn". For now, I avoided changing it,
so that the existing tests using the legacy ABI, namely test_tensorflow.c
and test_tensorflow_quantized.c, pass as they are.
If we have any users who still want to use the legacy ABI,
I suppose they consider compatibility more important
than consistency with the other backends.

cf. https://github.com/bytecodealliance/wasm-micro-runtime/issues/4376
YAMAMOTO Takashi · 8 months ago · commit 8289452abb
1 changed file with 57 additions and 16 deletions

core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp  +57 -16
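
To put numbers on the commit message's point: for the same fp32 output tensor, the
legacy convention reports the element count while the wasi_ephemeral_nn convention
reports the byte size, and the two differ by a factor of sizeof(float). The
standalone sketch below is illustration only; the shape [1, 224, 224, 3] is an
assumption made up for the example, whereas the real backend reads the shape
(and tensor->bytes) from the TfLiteTensor.

#include <stdint.h>
#include <stdio.h>

int
main(void)
{
    /* hypothetical fp32 output shape, for illustration only */
    const uint32_t dims[] = { 1, 224, 224, 3 };
    uint32_t elems = 1;
    for (uint32_t i = 0; i < sizeof(dims) / sizeof(dims[0]); ++i)
        elems *= dims[i];

    /* legacy wamr-specific "wasi_nn": number of fp32 values */
    printf("legacy size: %u fp32 values\n", elems); /* 150528 */
    /* "wasi_ephemeral_nn": byte size, consistent with other backends */
    printf("fixed size:  %u bytes\n", (uint32_t)(elems * sizeof(float))); /* 602112 */
    return 0;
}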

@@ -389,23 +389,34 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
         return too_large;
     }
 
-    uint32_t model_tensor_size = 1;
-    for (int i = 0; i < (int)tensor->dims->size; ++i)
-        model_tensor_size *= (uint32_t)tensor->dims->data[i];
-
-    if (*output_tensor_size < model_tensor_size) {
-        NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
-        return too_large;
-    }
-
     if (tensor->quantization.type == kTfLiteNoQuantization) {
         NN_DBG_PRINTF("No quantization information");
-        float *ot =
-            tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
-                index);
-
-        int size = model_tensor_size * sizeof(float);
-        bh_memcpy_s(output_tensor, size, ot, size);
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+        if (*output_tensor_size < tensor->bytes) {
+            NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
+            return too_large;
+        }
+#else
+        /*
+         * for now, maintain the bug-to-bug compatibility with the old abi,
+         * where the size here is the number of fp32, not bytes.
+         */
+        if (*output_tensor_size < tensor->bytes / sizeof(float)) {
+            NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
+            return too_large;
+        }
+#endif
+        bh_memcpy_s(output_tensor, *output_tensor_size, tensor->data.data,
+                    tensor->bytes);
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+        *output_tensor_size = tensor->bytes;
+#else
+        /*
+         * for now, maintain the bug-to-bug compatibility with the old abi,
+         * where the size here is the number of fp32, not bytes.
+         */
+        *output_tensor_size = tensor->bytes / sizeof(float);
+#endif
     }
     else { // TODO: Assuming uint8 quantized networks.
         TfLiteAffineQuantization *quant_info =
@@ -414,6 +425,27 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
             NN_ERR_PRINTF("Quantization per channel is not supported");
             return runtime_error;
         }
+
+        uint32_t model_tensor_size = 1;
+        for (int i = 0; i < (int)tensor->dims->size; ++i)
+            model_tensor_size *= (uint32_t)tensor->dims->data[i];
+
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+        if (*output_tensor_size / sizeof(float) < model_tensor_size) {
+            NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
+            return too_large;
+        }
+#else
+        /*
+         * for now, maintain the bug-to-bug compatibility with the old abi,
+         * where the size here is the number of fp32, not bytes.
+         */
+        if (*output_tensor_size < model_tensor_size) {
+            NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
+            return too_large;
+        }
+#endif
+
         uint8_t *ot = tfl_ctx->interpreters[ctx]
                           .interpreter->typed_output_tensor<uint8_t>(index);
 
@@ -426,9 +458,18 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
         for (uint32_t i = 0; i < model_tensor_size; ++i) {
             output_tensor_f[i] = (ot[i] - zero_point) * scale;
        }
+
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+        *output_tensor_size = model_tensor_size * sizeof(float);
+#else
+        /*
+         * for now, maintain the bug-to-bug compatibility with the old abi,
+         * where the size here is the number of fp32, not bytes.
+         */
+        *output_tensor_size = model_tensor_size;
+#endif
     }
 
-    *output_tensor_size = model_tensor_size;
     return success;
 }
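
A note on the quantized path in the last two hunks: the backend dequantizes the
uint8 output into the caller's buffer as fp32 (output_tensor_f[i] =
(ot[i] - zero_point) * scale), so the size it reports is derived from the element
count rather than from tensor->bytes, which is the uint8 size there. Below is a
minimal standalone sketch of that conversion and of the two reporting conventions;
the dequantize helper, the 4-element buffer, and the scale/zero_point values are
made up for illustration.

#include <stdint.h>
#include <stdio.h>

/* dequantize uint8 values to fp32, following f = (q - zero_point) * scale */
static void
dequantize(const uint8_t *q, float *f, uint32_t elems, float scale,
           int32_t zero_point)
{
    for (uint32_t i = 0; i < elems; ++i)
        f[i] = (q[i] - zero_point) * scale;
}

int
main(void)
{
    /* hypothetical quantized output: 4 elements, made-up scale/zero_point */
    const uint8_t q[4] = { 0, 64, 128, 255 };
    float f[4];
    const uint32_t elems = 4;

    dequantize(q, f, elems, 0.5f, 128);

    /* size the backend would report for this fp32 result after the fix */
    printf("wasi_ephemeral_nn: %u bytes\n", (uint32_t)(elems * sizeof(float)));
    /* size reported by the legacy wamr-specific "wasi_nn" */
    printf("legacy wasi_nn:    %u fp32 values\n", elems);
    return 0;
}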