Просмотр исходного кода

wasi_nn_llamacpp.c: validate input tensor type/dimensions (#4442)

YAMAMOTO Takashi 6 месяцев назад
Родитель
Сommit
ee056d8076
1 измененных файлов с 12 добавлено и 0 удалено
  1. 12 0
      core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c

+ 12 - 0
core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c

@@ -411,6 +411,18 @@ set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index,
     char *prompt_text = (char *)wasi_nn_tensor->data.buf;
     uint32_t prompt_text_len = wasi_nn_tensor->data.size;
 
+    // note: buf[0] == 1 is a workaround for
+    // https://github.com/second-state/WasmEdge-WASINN-examples/issues/196.
+    // we may remove it in future.
+    if (wasi_nn_tensor->type != u8 || wasi_nn_tensor->dimensions->size != 1
+        || !(wasi_nn_tensor->dimensions->buf[0] == 1
+             || wasi_nn_tensor->dimensions->buf[0] == prompt_text_len)) {
+        return invalid_argument;
+    }
+    if (wasi_nn_tensor->dimensions->buf[0] == 1 && prompt_text_len != 1) {
+        NN_WARN_PRINTF("Ignoring seemingly wrong input tensor dimensions.");
+    }
+
 #ifndef NDEBUG
     NN_DBG_PRINTF("--------------------------------------------------");
     NN_DBG_PRINTF("prompt_text: %.*s", (int)prompt_text_len, prompt_text);