
wasi_nn_llamacpp.c: reject invalid graph and execution context (#4422)

* Return a valid graph and execution context instead of using stack
  garbage. (Always 0 for now because this backend does not implement
  multiple graphs/contexts.)

* Validate user-given graph and execution context values and reject
  invalid ones.
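
Illustration (not part of the commit): with these checks, a caller that passes any handle other than 0, or that calls into the backend before a model/context exists, now gets an error instead of the backend silently ignoring the handle or writing an uninitialized value back. A minimal sketch of the expected behaviour, using the entry points and error codes visible in the diff below; `backend_ctx`, the model file names, and the direct calls to these backend-internal functions are assumptions for illustration only:

    /* Sketch only: assumes a struct LlamaContext *backend_ctx obtained
     * from the backend's init path (elided here) and a model on disk. */
    graph g;
    graph_execution_context ec;

    /* First load succeeds; the backend now hands back a real handle (0). */
    assert(__load_by_name_with_configuration(backend_ctx, "model.gguf", &g)
           == success && g == 0);

    /* A second load is rejected: this backend implements a single graph. */
    assert(__load_by_name_with_configuration(backend_ctx, "other.gguf", &g)
           == unsupported_operation);

    /* Any graph handle other than 0 is rejected up front. */
    assert(init_execution_context(backend_ctx, /* g = */ 1, &ec)
           == runtime_error);

    /* Same for execution-context handles in set_input/compute/get_output. */
    assert(compute(backend_ctx, /* exec_ctx = */ 1) == runtime_error);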
YAMAMOTO Takashi, 6 months ago
Parent commit: da6019f749
1 changed file with 33 additions and 0 deletions:
    core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c (+33, -0)

+ 33 - 0
core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c

@@ -305,6 +305,11 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g)
 {
     struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
 
+    if (backend_ctx->model != NULL) {
+        // we only implement a single graph
+        return unsupported_operation;
+    }
+
     // make sure backend_ctx->config is initialized
 
     struct llama_model_params model_params =
@@ -323,6 +328,7 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g)
 #endif
 
     backend_ctx->model = model;
+    *g = 0;
 
     return success;
 }
@@ -363,6 +369,16 @@ init_execution_context(void *ctx, graph g, graph_execution_context *exec_ctx)
 {
     struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
 
+    if (g != 0 || backend_ctx->model == NULL) {
+        // we only implement a single graph
+        return runtime_error;
+    }
+
+    if (backend_ctx->ctx != NULL) {
+        // we only implement a single context
+        return unsupported_operation;
+    }
+
     struct llama_context_params ctx_params =
         llama_context_params_from_wasi_nn_llama_config(&backend_ctx->config);
     struct llama_context *llama_ctx =
@@ -373,6 +389,7 @@ init_execution_context(void *ctx, graph g, graph_execution_context *exec_ctx)
     }
 
     backend_ctx->ctx = llama_ctx;
+    *exec_ctx = 0;
 
     NN_INFO_PRINTF("n_predict = %d, n_ctx = %d", backend_ctx->config.n_predict,
                    llama_n_ctx(backend_ctx->ctx));
@@ -384,6 +401,12 @@ set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index,
           tensor *wasi_nn_tensor)
 {
     struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
+
+    if (exec_ctx != 0 || backend_ctx->ctx == NULL) {
+        // we only implement a single context
+        return runtime_error;
+    }
+
     // tensor->data is the prompt string.
     char *prompt_text = (char *)wasi_nn_tensor->data.buf;
     uint32_t prompt_text_len = wasi_nn_tensor->data.size;
@@ -433,6 +456,11 @@ compute(void *ctx, graph_execution_context exec_ctx)
     struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
     wasi_nn_error ret = runtime_error;
 
+    if (exec_ctx != 0 || backend_ctx->ctx == NULL) {
+        // we only implement a single context
+        return runtime_error;
+    }
+
     // reset the generation buffer
     if (backend_ctx->generation == NULL) {
         backend_ctx->generation =
@@ -554,6 +582,11 @@ get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
 {
     struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
 
+    if (exec_ctx != 0 || backend_ctx->ctx == NULL) {
+        // we only implement a single context
+        return runtime_error;
+    }
+
     // Compatibility with WasmEdge
     if (index > 1) {
         NN_ERR_PRINTF("Invalid output index %d", index);