Преглед изворни кода

Add implementation for wasi_thread_spawn() (#1786)

For now this implementation uses thread manager.

Not sure whether thread manager is needed in that case. In the future there'll be likely another syscall added (for pthread_exit) and for that we might need some kind of thread management - with that in mind, we keep thread manager for now and will refactor this later if needed.
Marcin Kolny пре 3 година
родитељ
комит
929d5942b9

+ 1 - 0
build-scripts/runtime_lib.cmake

@@ -127,6 +127,7 @@ endif ()
 if (WAMR_BUILD_LIB_WASI_THREADS EQUAL 1)
     include (${IWASM_DIR}/libraries/lib-wasi-threads/lib_wasi_threads.cmake)
     # Enable the dependent feature if lib wasi threads is enabled
+    set (WAMR_BUILD_THREAD_MGR 1)
     set (WAMR_BUILD_BULK_MEMORY 1)
     set (WAMR_BUILD_SHARED_MEMORY 1)
 endif ()

+ 14 - 1
core/iwasm/common/wasm_native.c

@@ -54,6 +54,12 @@ get_lib_pthread_export_apis(NativeSymbol **p_lib_pthread_apis);
 #endif
 
 #if WASM_ENABLE_LIB_WASI_THREADS != 0
+bool
+lib_wasi_threads_init(void);
+
+void
+lib_wasi_threads_destroy(void);
+
 uint32
 get_lib_wasi_threads_export_apis(NativeSymbol **p_lib_wasi_threads_apis);
 #endif
@@ -444,6 +450,9 @@ wasm_native_init()
 #endif
 
 #if WASM_ENABLE_LIB_WASI_THREADS != 0
+    if (!lib_wasi_threads_init())
+        goto fail;
+
     n_native_symbols = get_lib_wasi_threads_export_apis(&native_symbols);
     if (n_native_symbols > 0
         && !wasm_native_register_natives("wasi", native_symbols,
@@ -471,7 +480,7 @@ wasm_native_init()
     n_native_symbols = get_wasi_nn_export_apis(&native_symbols);
     if (!wasm_native_register_natives("wasi_nn", native_symbols,
                                       n_native_symbols))
-        return false;
+        goto fail;
 #endif
 
     return true;
@@ -495,6 +504,10 @@ wasm_native_destroy()
     lib_pthread_destroy();
 #endif
 
+#if WASM_ENABLE_LIB_WASI_THREADS != 0
+    lib_wasi_threads_destroy();
+#endif
+
     node = g_native_symbols_list;
     while (node) {
         node_next = node->next;

+ 139 - 4
core/iwasm/libraries/lib-wasi-threads/lib_wasi_threads_wrapper.c

@@ -3,7 +3,9 @@
  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  */
 
+#include "bh_log.h"
 #include "wasmtime_ssp.h"
+#include "thread_manager.h"
 
 #if WASM_ENABLE_INTERP != 0
 #include "wasm_runtime.h"
@@ -13,10 +15,128 @@
 #include "aot_runtime.h"
 #endif
 
-static __wasi_errno_t
-thread_spawn_wrapper(wasm_exec_env_t exec_env, void *start_arg)
+static const char *THREAD_START_FUNCTION = "wasi_thread_start";
+
+static korp_mutex thread_id_lock;
+
+typedef struct {
+    /* app's entry function */
+    wasm_function_inst_t start_func;
+    /* arg of the app's entry function */
+    uint32 arg;
+    /* thread id passed to the app */
+    int32 thread_id;
+} ThreadStartArg;
+
+static int32
+allocate_thread_id()
+{
+    static int32 thread_id = 0;
+
+    int32 id;
+
+    os_mutex_lock(&thread_id_lock);
+    id = thread_id++;
+    os_mutex_unlock(&thread_id_lock);
+    return id;
+}
+
+static void *
+thread_start(void *arg)
 {
-    return __WASI_ENOSYS;
+    wasm_exec_env_t exec_env = (wasm_exec_env_t)arg;
+    wasm_module_inst_t module_inst = get_module_inst(exec_env);
+    ThreadStartArg *thread_arg = exec_env->thread_arg;
+    uint32 argv[2];
+
+    wasm_exec_env_set_thread_info(exec_env);
+    argv[0] = thread_arg->thread_id;
+    argv[1] = thread_arg->arg;
+
+    if (!wasm_runtime_call_wasm(exec_env, thread_arg->start_func, 2, argv)) {
+        if (wasm_runtime_get_exception(module_inst))
+            wasm_cluster_spread_exception(exec_env);
+    }
+
+    /* routine exit, destroy instance */
+    wasm_runtime_deinstantiate_internal(module_inst, true);
+
+    wasm_runtime_free(thread_arg);
+    exec_env->thread_arg = NULL;
+
+    return NULL;
+}
+
+static int32
+thread_spawn_wrapper(wasm_exec_env_t exec_env, uint32 start_arg)
+{
+    wasm_module_t module = wasm_exec_env_get_module(exec_env);
+    wasm_module_inst_t module_inst = get_module_inst(exec_env);
+    wasm_module_inst_t new_module_inst = NULL;
+    ThreadStartArg *thread_start_arg = NULL;
+    wasm_function_inst_t start_func;
+    int32 thread_id;
+    uint32 stack_size = 8192;
+    int32 ret = -1;
+#if WASM_ENABLE_LIBC_WASI != 0
+    WASIContext *wasi_ctx;
+#endif
+
+    bh_assert(module);
+    bh_assert(module_inst);
+
+    stack_size = ((WASMModuleInstance *)module_inst)->default_wasm_stack_size;
+
+    if (!(new_module_inst = wasm_runtime_instantiate_internal(
+              module, true, stack_size, 0, NULL, 0)))
+        return -1;
+
+    wasm_runtime_set_custom_data_internal(
+        new_module_inst, wasm_runtime_get_custom_data(module_inst));
+
+#if WASM_ENABLE_LIBC_WASI != 0
+    wasi_ctx = wasm_runtime_get_wasi_ctx(module_inst);
+    if (wasi_ctx)
+        wasm_runtime_set_wasi_ctx(new_module_inst, wasi_ctx);
+#endif
+
+    start_func = wasm_runtime_lookup_function(new_module_inst,
+                                              THREAD_START_FUNCTION, NULL);
+    if (!start_func) {
+        LOG_ERROR("Failed to find thread start function %s",
+                  THREAD_START_FUNCTION);
+        goto thread_spawn_fail;
+    }
+
+    if (!(thread_start_arg = wasm_runtime_malloc(sizeof(ThreadStartArg)))) {
+        LOG_ERROR("Runtime args allocation failed");
+        goto thread_spawn_fail;
+    }
+
+    thread_start_arg->thread_id = thread_id = allocate_thread_id();
+    thread_start_arg->arg = start_arg;
+    thread_start_arg->start_func = start_func;
+
+    os_mutex_lock(&exec_env->wait_lock);
+    ret = wasm_cluster_create_thread(exec_env, new_module_inst, thread_start,
+                                     thread_start_arg);
+    if (ret != 0) {
+        os_mutex_unlock(&exec_env->wait_lock);
+        LOG_ERROR("Failed to spawn a new thread");
+        goto thread_spawn_fail;
+    }
+    os_mutex_unlock(&exec_env->wait_lock);
+
+    return thread_id;
+
+thread_spawn_fail:
+    if (new_module_inst)
+        wasm_runtime_deinstantiate_internal(new_module_inst, true);
+
+    if (thread_start_arg)
+        wasm_runtime_free(thread_start_arg);
+
+    return -1;
 }
 
 /* clang-format off */
@@ -25,7 +145,7 @@ thread_spawn_wrapper(wasm_exec_env_t exec_env, void *start_arg)
 /* clang-format on */
 
 static NativeSymbol native_symbols_lib_wasi_threads[] = { REG_NATIVE_FUNC(
-    thread_spawn, "(*)i") };
+    thread_spawn, "(i)i") };
 
 uint32
 get_lib_wasi_threads_export_apis(NativeSymbol **p_lib_wasi_threads_apis)
@@ -33,3 +153,18 @@ get_lib_wasi_threads_export_apis(NativeSymbol **p_lib_wasi_threads_apis)
     *p_lib_wasi_threads_apis = native_symbols_lib_wasi_threads;
     return sizeof(native_symbols_lib_wasi_threads) / sizeof(NativeSymbol);
 }
+
+bool
+lib_wasi_threads_init(void)
+{
+    if (0 != os_mutex_init(&thread_id_lock))
+        return false;
+
+    return true;
+}
+
+void
+lib_wasi_threads_destroy(void)
+{
+    os_mutex_destroy(&thread_id_lock);
+}

+ 12 - 6
samples/wasi-threads/wasm-apps/no_pthread.c

@@ -8,6 +8,7 @@
 
 #include <stdlib.h>
 #include <stdio.h>
+#include <assert.h>
 #include <wasi/api.h>
 
 static const int64_t SECOND = 1000 * 1000 * 1000;
@@ -15,6 +16,7 @@ static const int64_t SECOND = 1000 * 1000 * 1000;
 typedef struct {
     int th_ready;
     int value;
+    int thread_id;
 } shared_t;
 
 __attribute__((export_name("wasi_thread_start"))) void
@@ -25,6 +27,7 @@ wasi_thread_start(int thread_id, int *start_arg)
     printf("New thread ID: %d, starting parameter: %d\n", thread_id,
            data->value);
 
+    data->thread_id = thread_id;
     data->value += 8;
     printf("Updated value: %d\n", data->value);
 
@@ -35,12 +38,12 @@ wasi_thread_start(int thread_id, int *start_arg)
 int
 main(int argc, char **argv)
 {
-    shared_t data = { 0, 52 };
-    __wasi_errno_t err;
+    shared_t data = { 0, 52, -1 };
+    int thread_id;
 
-    err = __wasi_thread_spawn(&data);
-    if (err != __WASI_ERRNO_SUCCESS) {
-        printf("Failed to create thread: %d\n", err);
+    thread_id = __wasi_thread_spawn(&data);
+    if (thread_id < 0) {
+        printf("Failed to create thread: %d\n", thread_id);
         return EXIT_FAILURE;
     }
 
@@ -49,7 +52,10 @@ main(int argc, char **argv)
         return EXIT_FAILURE;
     }
 
-    printf("Thread completed, new value: %d\n", data.value);
+    printf("Thread completed, new value: %d, thread id: %d\n", data.value,
+           data.thread_id);
+
+    assert(thread_id == data.thread_id);
 
     return EXIT_SUCCESS;
 }