Bladeren bron

Fix/Simplify the atomic.wait/notify implementations (#2044)

Use the shared memory's shared_mem_lock to lock the whole atomic.wait and
atomic.notify processes, and use it for os_cond_reltimedwait and os_cond_notify,
so as to make the whole processes actually atomic operations:
the original implementation accesses the wait address under shared_mem_lock
but uses wait_node->wait_lock for os_cond_reltimedwait, so the combination is
not an atomic operation.

And remove the unnecessary wait_map_lock and wait_lock, since the whole
processes are already locked by shared_mem_lock.
Wenyong Huang 2 jaren geleden
bovenliggende
commit
49d439a3bc

+ 75 - 135
core/iwasm/common/wasm_shared_memory.c

@@ -21,28 +21,20 @@ enum {
 /* clang-format on */
 
 typedef struct AtomicWaitInfo {
-    korp_mutex wait_list_lock;
     bh_list wait_list_head;
     bh_list *wait_list;
     /* WARNING: insert to the list allowed only in acquire_wait_info
-    otherwise there will be data race as described in PR #2016 */
+       otherwise there will be data race as described in PR #2016 */
 } AtomicWaitInfo;
 
 typedef struct AtomicWaitNode {
     bh_list_link l;
     uint8 status;
-    korp_mutex wait_lock;
     korp_cond wait_cond;
 } AtomicWaitNode;
 
-typedef struct AtomicWaitAddressArgs {
-    uint32 index;
-    void **addr;
-} AtomicWaitAddressArgs;
-
 /* Atomic wait map */
 static HashMap *wait_map;
-static korp_mutex wait_map_lock;
 
 static uint32
 wait_address_hash(void *address);
@@ -59,17 +51,11 @@ wasm_shared_memory_init()
     if (os_mutex_init(&shared_memory_list_lock) != 0)
         return false;
 
-    if (os_mutex_init(&wait_map_lock) != 0) {
-        os_mutex_destroy(&shared_memory_list_lock);
-        return false;
-    }
-
     /* wait map not exists, create new map */
     if (!(wait_map = bh_hash_map_create(32, true, (HashFunc)wait_address_hash,
                                         (KeyEqualFunc)wait_address_equal, NULL,
                                         destroy_wait_info))) {
         os_mutex_destroy(&shared_memory_list_lock);
-        os_mutex_destroy(&wait_map_lock);
         return false;
     }
 
@@ -79,11 +65,8 @@ wasm_shared_memory_init()
 void
 wasm_shared_memory_destroy()
 {
+    bh_hash_map_destroy(wait_map);
     os_mutex_destroy(&shared_memory_list_lock);
-    os_mutex_destroy(&wait_map_lock);
-    if (wait_map) {
-        bh_hash_map_destroy(wait_map);
-    }
 }
 
 static WASMSharedMemNode *
@@ -224,7 +207,7 @@ notify_wait_list(bh_list *wait_list, uint32 count)
     AtomicWaitNode *node, *next;
     uint32 i, notify_count = count;
 
-    if ((count == UINT32_MAX) || (count > wait_list->len))
+    if (count > wait_list->len)
         notify_count = wait_list->len;
 
     node = bh_list_first_elem(wait_list);
@@ -235,11 +218,9 @@ notify_wait_list(bh_list *wait_list, uint32 count)
         bh_assert(node);
         next = bh_list_elem_next(node);
 
-        os_mutex_lock(&node->wait_lock);
         node->status = S_NOTIFIED;
         /* wakeup */
         os_cond_signal(&node->wait_cond);
-        os_mutex_unlock(&node->wait_lock);
 
         node = next;
     }
@@ -253,13 +234,10 @@ acquire_wait_info(void *address, AtomicWaitNode *wait_node)
     AtomicWaitInfo *wait_info = NULL;
     bh_list_status ret;
 
-    os_mutex_lock(&wait_map_lock); /* Make find + insert atomic */
-
     if (address)
         wait_info = (AtomicWaitInfo *)bh_hash_map_find(wait_map, address);
 
     if (!wait_node) {
-        os_mutex_unlock(&wait_map_lock);
         return wait_info;
     }
 
@@ -267,7 +245,7 @@ acquire_wait_info(void *address, AtomicWaitNode *wait_node)
     if (!wait_info) {
         if (!(wait_info = (AtomicWaitInfo *)wasm_runtime_malloc(
                   sizeof(AtomicWaitInfo)))) {
-            goto fail1;
+            return NULL;
         }
         memset(wait_info, 0, sizeof(AtomicWaitInfo));
 
@@ -276,38 +254,17 @@ acquire_wait_info(void *address, AtomicWaitNode *wait_node)
         ret = bh_list_init(wait_info->wait_list);
         bh_assert(ret == BH_LIST_SUCCESS);
 
-        /* init wait list lock */
-        if (0 != os_mutex_init(&wait_info->wait_list_lock)) {
-            goto fail2;
-        }
-
         if (!bh_hash_map_insert(wait_map, address, (void *)wait_info)) {
-            goto fail3;
+            wasm_runtime_free(wait_info);
+            return NULL;
         }
     }
 
-    os_mutex_lock(&wait_info->wait_list_lock);
     ret = bh_list_insert(wait_info->wait_list, wait_node);
-    os_mutex_unlock(&wait_info->wait_list_lock);
     bh_assert(ret == BH_LIST_SUCCESS);
     (void)ret;
 
-    os_mutex_unlock(&wait_map_lock);
-
-    bh_assert(wait_info);
-    (void)ret;
     return wait_info;
-
-fail3:
-    os_mutex_destroy(&wait_info->wait_list_lock);
-
-fail2:
-    wasm_runtime_free(wait_info);
-
-fail1:
-    os_mutex_unlock(&wait_map_lock);
-
-    return NULL;
 }
 
 static void
@@ -321,13 +278,11 @@ destroy_wait_info(void *wait_info)
 
         while (node) {
             next = bh_list_elem_next(node);
-            os_mutex_destroy(&node->wait_lock);
             os_cond_destroy(&node->wait_cond);
             wasm_runtime_free(node);
             node = next;
         }
 
-        os_mutex_destroy(&((AtomicWaitInfo *)wait_info)->wait_list_lock);
         wasm_runtime_free(wait_info);
     }
 }
@@ -336,17 +291,11 @@ static void
 map_try_release_wait_info(HashMap *wait_map_, AtomicWaitInfo *wait_info,
                           void *address)
 {
-    os_mutex_lock(&wait_map_lock);
-    os_mutex_lock(&wait_info->wait_list_lock);
     if (wait_info->wait_list->len > 0) {
-        os_mutex_unlock(&wait_info->wait_list_lock);
-        os_mutex_unlock(&wait_map_lock);
         return;
     }
-    os_mutex_unlock(&wait_info->wait_list_lock);
 
     bh_hash_map_remove(wait_map_, address, NULL, NULL);
-    os_mutex_unlock(&wait_map_lock);
     destroy_wait_info(wait_info);
 }
 
@@ -361,6 +310,7 @@ wasm_runtime_atomic_wait(WASMModuleInstanceCommon *module, void *address,
 #if WASM_ENABLE_THREAD_MGR != 0
     WASMExecEnv *exec_env;
 #endif
+    uint64 timeout_left, timeout_wait, timeout_1sec;
     bool check_ret, is_timeout, no_wait;
 
     bh_assert(module->module_type == Wasm_Module_Bytecode
@@ -383,124 +333,108 @@ wasm_runtime_atomic_wait(WASMModuleInstanceCommon *module, void *address,
         return -1;
     }
 
+#if WASM_ENABLE_THREAD_MGR != 0
+    exec_env =
+        wasm_clusters_search_exec_env((WASMModuleInstanceCommon *)module_inst);
+    bh_assert(exec_env);
+#endif
+
     node = search_module((WASMModuleCommon *)module_inst->module);
+    bh_assert(node);
+
+    /* Lock the shared_mem_lock for the whole atomic wait process,
+       and use it to os_cond_reltimedwait */
     os_mutex_lock(&node->shared_mem_lock);
+
     no_wait = (!wait64 && *(uint32 *)address != (uint32)expect)
               || (wait64 && *(uint64 *)address != expect);
-    os_mutex_unlock(&node->shared_mem_lock);
 
     if (no_wait) {
+        os_mutex_unlock(&node->shared_mem_lock);
         return 1;
     }
 
     if (!(wait_node = wasm_runtime_malloc(sizeof(AtomicWaitNode)))) {
+        os_mutex_unlock(&node->shared_mem_lock);
         wasm_runtime_set_exception(module, "failed to create wait node");
         return -1;
     }
     memset(wait_node, 0, sizeof(AtomicWaitNode));
 
-    if (0 != os_mutex_init(&wait_node->wait_lock)) {
-        wasm_runtime_free(wait_node);
-        return -1;
-    }
-
     if (0 != os_cond_init(&wait_node->wait_cond)) {
-        os_mutex_destroy(&wait_node->wait_lock);
+        os_mutex_unlock(&node->shared_mem_lock);
         wasm_runtime_free(wait_node);
+        wasm_runtime_set_exception(module, "failed to init wait cond");
         return -1;
     }
 
     wait_node->status = S_WAITING;
 
-    /* acquire the wait info, create new one if not exists */
+    /* Acquire the wait info, create new one if not exists */
     wait_info = acquire_wait_info(address, wait_node);
 
     if (!wait_info) {
-        os_mutex_destroy(&wait_node->wait_lock);
+        os_mutex_unlock(&node->shared_mem_lock);
+        os_cond_destroy(&wait_node->wait_cond);
         wasm_runtime_free(wait_node);
         wasm_runtime_set_exception(module, "failed to acquire wait_info");
         return -1;
     }
 
+    /* unit of timeout is nsec, convert it to usec */
+    timeout_left = (uint64)timeout / 1000;
+    timeout_1sec = 1e6;
+
+    while (1) {
+        if (timeout < 0) {
+            /* wait forever until it is notified or terminatied
+               here we keep waiting and checking every second */
+            os_cond_reltimedwait(&wait_node->wait_cond, &node->shared_mem_lock,
+                                 (uint64)timeout_1sec);
+            if (wait_node->status == S_NOTIFIED /* notified by atomic.notify */
 #if WASM_ENABLE_THREAD_MGR != 0
-    exec_env =
-        wasm_clusters_search_exec_env((WASMModuleInstanceCommon *)module_inst);
-    bh_assert(exec_env);
-#endif
-
-    os_mutex_lock(&node->shared_mem_lock);
-    no_wait = (!wait64 && *(uint32 *)address != (uint32)expect)
-              || (wait64 && *(uint64 *)address != expect);
-    os_mutex_unlock(&node->shared_mem_lock);
-
-    /* condition wait start */
-    os_mutex_lock(&wait_node->wait_lock);
-
-    if (!no_wait) {
-        /* unit of timeout is nsec, convert it to usec */
-        uint64 timeout_left = (uint64)timeout / 1000, timeout_wait;
-        uint64 timeout_1sec = 1e6;
-
-        while (1) {
-            if (timeout < 0) {
-                /* wait forever until it is notified or terminatied
-                   here we keep waiting and checking every second */
-                os_cond_reltimedwait(&wait_node->wait_cond,
-                                     &wait_node->wait_lock,
-                                     (uint64)timeout_1sec);
-                if (wait_node->status
-                        == S_NOTIFIED /* notified by atomic.notify */
-#if WASM_ENABLE_THREAD_MGR != 0
-                    /* terminated by other thread */
-                    || wasm_cluster_is_thread_terminated(exec_env)
+                /* terminated by other thread */
+                || wasm_cluster_is_thread_terminated(exec_env)
 #endif
-                ) {
-                    break;
-                }
-                /* continue to wait */
+            ) {
+                break;
             }
-            else {
-                timeout_wait =
-                    timeout_left < timeout_1sec ? timeout_left : timeout_1sec;
-                os_cond_reltimedwait(&wait_node->wait_cond,
-                                     &wait_node->wait_lock, timeout_wait);
-                if (wait_node->status
-                        == S_NOTIFIED /* notified by atomic.notify */
-                    || timeout_left <= timeout_wait /* time out */
+        }
+        else {
+            timeout_wait =
+                timeout_left < timeout_1sec ? timeout_left : timeout_1sec;
+            os_cond_reltimedwait(&wait_node->wait_cond, &node->shared_mem_lock,
+                                 timeout_wait);
+            if (wait_node->status == S_NOTIFIED /* notified by atomic.notify */
+                || timeout_left <= timeout_wait /* time out */
 #if WASM_ENABLE_THREAD_MGR != 0
-                    /* terminated by other thread */
-                    || wasm_cluster_is_thread_terminated(exec_env)
+                /* terminated by other thread */
+                || wasm_cluster_is_thread_terminated(exec_env)
 #endif
-                ) {
-                    break;
-                }
-                timeout_left -= timeout_wait;
+            ) {
+                break;
             }
+            timeout_left -= timeout_wait;
         }
     }
 
     is_timeout = wait_node->status == S_WAITING ? true : false;
-    os_mutex_unlock(&wait_node->wait_lock);
-
-    os_mutex_lock(&node->shared_mem_lock);
-    os_mutex_lock(&wait_info->wait_list_lock);
 
     check_ret = is_wait_node_exists(wait_info->wait_list, wait_node);
     bh_assert(check_ret);
+    (void)check_ret;
 
-    /* Remove wait node */
+    /* Remove wait node from wait list */
     bh_list_remove(wait_info->wait_list, wait_node);
-    os_mutex_destroy(&wait_node->wait_lock);
     os_cond_destroy(&wait_node->wait_cond);
     wasm_runtime_free(wait_node);
 
-    /* Release wait info if no wait nodes attached */
-    os_mutex_unlock(&wait_info->wait_list_lock);
+    /* Release wait info if no wait nodes are attached */
     map_try_release_wait_info(wait_map, wait_info, address);
+
     os_mutex_unlock(&node->shared_mem_lock);
 
-    (void)check_ret;
-    return no_wait ? 1 : is_timeout ? 2 : 0;
+    return is_timeout ? 2 : 0;
 }
 
 uint32
@@ -516,35 +450,41 @@ wasm_runtime_atomic_notify(WASMModuleInstanceCommon *module, void *address,
     bh_assert(module->module_type == Wasm_Module_Bytecode
               || module->module_type == Wasm_Module_AoT);
 
-    node = search_module((WASMModuleCommon *)module_inst->module);
-    if (node)
-        os_mutex_lock(&node->shared_mem_lock);
     out_of_bounds =
         ((uint8 *)address < module_inst->memories[0]->memory_data
          || (uint8 *)address + 4 > module_inst->memories[0]->memory_data_end);
 
     if (out_of_bounds) {
-        if (node)
-            os_mutex_unlock(&node->shared_mem_lock);
         wasm_runtime_set_exception(module, "out of bounds memory access");
         return -1;
     }
 
+    /* Currently we have only one memory instance */
+    if (!module_inst->memories[0]->is_shared) {
+        /* Always return 0 for ushared linear memory since there is
+           no way to create a waiter on it */
+        return 0;
+    }
+
+    node = search_module((WASMModuleCommon *)module_inst->module);
+    bh_assert(node);
+
+    /* Lock the shared_mem_lock for the whole atomic notify process,
+       and use it to os_cond_signal */
+    os_mutex_lock(&node->shared_mem_lock);
+
     wait_info = acquire_wait_info(address, NULL);
 
     /* Nobody wait on this address */
     if (!wait_info) {
-        if (node)
-            os_mutex_unlock(&node->shared_mem_lock);
+        os_mutex_unlock(&node->shared_mem_lock);
         return 0;
     }
 
-    os_mutex_lock(&wait_info->wait_list_lock);
+    /* Notify each wait node in the wait list */
     notify_result = notify_wait_list(wait_info->wait_list, count);
-    os_mutex_unlock(&wait_info->wait_list_lock);
 
-    if (node)
-        os_mutex_unlock(&node->shared_mem_lock);
+    os_mutex_unlock(&node->shared_mem_lock);
 
     return notify_result;
 }

+ 6 - 7
core/iwasm/interpreter/wasm_interp_classic.c

@@ -3414,7 +3414,8 @@ wasm_interp_call_func_bytecode(WASMModuleInstance *module,
                         ret = wasm_runtime_atomic_notify(
                             (WASMModuleInstanceCommon *)module, maddr,
                             notify_count);
-                        bh_assert((int32)ret >= 0);
+                        if (ret == (uint32)-1)
+                            goto got_exception;
 
                         PUSH_I32(ret);
                         break;
@@ -3471,7 +3472,7 @@ wasm_interp_call_func_bytecode(WASMModuleInstance *module,
                     {
                         /* Skip the memory index */
                         frame_ip++;
-                        os_atomic_thread_fence(os_memory_order_release);
+                        os_atomic_thread_fence(os_memory_order_seq_cst);
                         break;
                     }
 
@@ -3578,7 +3579,7 @@ wasm_interp_call_func_bytecode(WASMModuleInstance *module,
                             CHECK_BULK_MEMORY_OVERFLOW(addr + offset, 4, maddr);
                             CHECK_ATOMIC_MEMORY_ACCESS();
                             os_mutex_lock(&node->shared_mem_lock);
-                            STORE_U32(maddr, frame_sp[1]);
+                            STORE_U32(maddr, sval);
                             os_mutex_unlock(&node->shared_mem_lock);
                         }
                         break;
@@ -3619,8 +3620,7 @@ wasm_interp_call_func_bytecode(WASMModuleInstance *module,
                             CHECK_BULK_MEMORY_OVERFLOW(addr + offset, 8, maddr);
                             CHECK_ATOMIC_MEMORY_ACCESS();
                             os_mutex_lock(&node->shared_mem_lock);
-                            PUT_I64_TO_ADDR((uint32 *)maddr,
-                                            GET_I64_FROM_ADDR(frame_sp + 1));
+                            PUT_I64_TO_ADDR((uint32 *)maddr, sval);
                             os_mutex_unlock(&node->shared_mem_lock);
                         }
                         break;
@@ -3721,9 +3721,8 @@ wasm_interp_call_func_bytecode(WASMModuleInstance *module,
 
                             os_mutex_lock(&node->shared_mem_lock);
                             readv = (uint64)LOAD_I64(maddr);
-                            if (readv == expect) {
+                            if (readv == expect)
                                 STORE_I64(maddr, sval);
-                            }
                             os_mutex_unlock(&node->shared_mem_lock);
                         }
                         PUSH_I64(readv);

+ 4 - 4
core/iwasm/interpreter/wasm_interp_fast.c

@@ -3252,7 +3252,8 @@ wasm_interp_call_func_bytecode(WASMModuleInstance *module,
                         ret = wasm_runtime_atomic_notify(
                             (WASMModuleInstanceCommon *)module, maddr,
                             notify_count);
-                        bh_assert((int32)ret >= 0);
+                        if (ret == (uint32)-1)
+                            goto got_exception;
 
                         PUSH_I32(ret);
                         break;
@@ -3307,7 +3308,7 @@ wasm_interp_call_func_bytecode(WASMModuleInstance *module,
                     }
                     case WASM_OP_ATOMIC_FENCE:
                     {
-                        os_atomic_thread_fence(os_memory_order_release);
+                        os_atomic_thread_fence(os_memory_order_seq_cst);
                         break;
                     }
 
@@ -3555,9 +3556,8 @@ wasm_interp_call_func_bytecode(WASMModuleInstance *module,
 
                             os_mutex_lock(&node->shared_mem_lock);
                             readv = (uint64)LOAD_I64(maddr);
-                            if (readv == expect) {
+                            if (readv == expect)
                                 STORE_I64(maddr, sval);
-                            }
                             os_mutex_unlock(&node->shared_mem_lock);
                         }
                         PUSH_I64(readv);

+ 2 - 0
core/shared/platform/include/platform_api_extension.h

@@ -121,7 +121,9 @@ os_thread_exit(void *retval);
 
 #if defined(BH_HAS_STD_ATOMIC) && !defined(__cplusplus)
 #include <stdatomic.h>
+#define os_memory_order_acquire memory_order_acquire
 #define os_memory_order_release memory_order_release
+#define os_memory_order_seq_cst memory_order_seq_cst
 #define os_atomic_thread_fence atomic_thread_fence
 #endif
 

+ 2 - 0
core/shared/platform/linux-sgx/platform_internal.h

@@ -63,7 +63,9 @@ os_set_print_function(os_print_function_t pf);
 char *
 strcpy(char *dest, const char *src);
 
+#define os_memory_order_acquire __ATOMIC_ACQUIRE
 #define os_memory_order_release __ATOMIC_RELEASE
+#define os_memory_order_seq_cst __ATOMIC_SEQ_CST
 #define os_atomic_thread_fence __atomic_thread_fence
 
 #ifdef __cplusplus