Parcourir la source

iwasm: Fix native lib cleanup after error occurs (#2443)

- If something failed after calling init_native_lib, call deinit_native_lib to clean up.
- Restructure the relevant code for better maintainability.
YAMAMOTO Takashi il y a 2 ans
Parent
commit
5780effc07
1 fichiers modifiés avec 97 ajouts et 85 suppressions
  1. 97 85
      product-mini/platforms/posix/main.c

+ 97 - 85
product-mini/platforms/posix/main.c

@@ -282,113 +282,124 @@ validate_env_str(char *env)
 #endif
 
 #if BH_HAS_DLFCN
-typedef uint32 (*get_native_lib_func)(char **p_module_name,
-                                      NativeSymbol **p_native_symbols);
-typedef int (*init_native_lib_func)(void);
-typedef void (*deinit_native_lib_func)(void);
+struct native_lib {
+    void *handle;
+
+    uint32 (*get_native_lib)(char **p_module_name,
+                             NativeSymbol **p_native_symbols);
+    int (*init_native_lib)(void);
+    void (*deinit_native_lib)(void);
 
-static uint32
-load_and_register_native_libs(const char **native_lib_list,
-                              uint32 native_lib_count,
-                              void **native_handle_list)
-{
-    uint32 i, native_handle_count = 0, n_native_symbols;
-    NativeSymbol *native_symbols;
     char *module_name;
-    void *handle;
+    NativeSymbol *native_symbols;
+    uint32 n_native_symbols;
+};
 
-    for (i = 0; i < native_lib_count; i++) {
-        /* open the native library */
-        if (!(handle = dlopen(native_lib_list[i], RTLD_NOW | RTLD_GLOBAL))
-            && !(handle = dlopen(native_lib_list[i], RTLD_LAZY))) {
-            LOG_WARNING("warning: failed to load native library %s",
-                        native_lib_list[i]);
-            continue;
-        }
+struct native_lib *
+load_native_lib(const char *name)
+{
+    struct native_lib *lib = wasm_runtime_malloc(sizeof(*lib));
+    if (lib == NULL) {
+        LOG_WARNING("warning: failed to load native library %s because of "
+                    "allocation failure",
+                    name);
+        goto fail;
+    }
+    memset(lib, 0, sizeof(*lib));
 
-        init_native_lib_func init_native_lib = dlsym(handle, "init_native_lib");
-        if (init_native_lib) {
-            int ret = init_native_lib();
-            if (ret != 0) {
-                LOG_WARNING("warning: `init_native_lib` function from native "
-                            "lib %s failed with %d",
-                            native_lib_list[i], ret);
-                dlclose(handle);
-                continue;
-            }
+    /* open the native library */
+    if (!(lib->handle = dlopen(name, RTLD_NOW | RTLD_GLOBAL))
+        && !(lib->handle = dlopen(name, RTLD_LAZY))) {
+        LOG_WARNING("warning: failed to load native library %s", name);
+        goto fail;
+    }
+
+    lib->init_native_lib = dlsym(lib->handle, "init_native_lib");
+    lib->get_native_lib = dlsym(lib->handle, "get_native_lib");
+    lib->deinit_native_lib = dlsym(lib->handle, "deinit_native_lib");
+
+    if (!lib->get_native_lib) {
+        LOG_WARNING("warning: failed to lookup `get_native_lib` function "
+                    "from native lib %s",
+                    name);
+        goto fail;
+    }
+
+    if (lib->init_native_lib) {
+        int ret = lib->init_native_lib();
+        if (ret != 0) {
+            LOG_WARNING("warning: `init_native_lib` function from native "
+                        "lib %s failed with %d",
+                        name);
+            goto fail;
         }
+    }
 
-        /* lookup get_native_lib func */
-        get_native_lib_func get_native_lib = dlsym(handle, "get_native_lib");
-        if (!get_native_lib) {
-            LOG_WARNING("warning: failed to lookup `get_native_lib` function "
-                        "from native lib %s",
-                        native_lib_list[i]);
-            dlclose(handle);
-            continue;
+    lib->n_native_symbols =
+        lib->get_native_lib(&lib->module_name, &lib->native_symbols);
+
+    /* register native symbols */
+    if (!(lib->n_native_symbols > 0 && lib->module_name && lib->native_symbols
+          && wasm_runtime_register_natives(
+              lib->module_name, lib->native_symbols, lib->n_native_symbols))) {
+        LOG_WARNING("warning: failed to register native lib %s", name);
+        if (lib->deinit_native_lib) {
+            lib->deinit_native_lib();
+        }
+        goto fail;
+    }
+    return lib;
+fail:
+    if (lib != NULL) {
+        if (lib->handle != NULL) {
+            dlclose(lib->handle);
         }
+        wasm_runtime_free(lib);
+    }
+    return NULL;
+}
 
-        n_native_symbols = get_native_lib(&module_name, &native_symbols);
+static uint32
+load_and_register_native_libs(const char **native_lib_list,
+                              uint32 native_lib_count,
+                              struct native_lib **native_lib_loaded_list)
+{
+    uint32 i, native_lib_loaded_count = 0;
 
-        /* register native symbols */
-        if (!(n_native_symbols > 0 && module_name && native_symbols
-              && wasm_runtime_register_natives(module_name, native_symbols,
-                                               n_native_symbols))) {
-            LOG_WARNING("warning: failed to register native lib %s",
-                        native_lib_list[i]);
-            dlclose(handle);
+    for (i = 0; i < native_lib_count; i++) {
+        struct native_lib *lib = load_native_lib(native_lib_list[i]);
+        if (lib == NULL) {
             continue;
         }
-
-        native_handle_list[native_handle_count++] = handle;
+        native_lib_loaded_list[native_lib_loaded_count++] = lib;
     }
 
-    return native_handle_count;
+    return native_lib_loaded_count;
 }
 
 static void
 unregister_and_unload_native_libs(uint32 native_lib_count,
-                                  void **native_handle_list)
+                                  struct native_lib **native_lib_loaded_list)
 {
-    uint32 i, n_native_symbols;
-    NativeSymbol *native_symbols;
-    char *module_name;
-    void *handle;
+    uint32 i;
 
     for (i = 0; i < native_lib_count; i++) {
-        handle = native_handle_list[i];
-
-        /* lookup get_native_lib func */
-        get_native_lib_func get_native_lib = dlsym(handle, "get_native_lib");
-        if (!get_native_lib) {
-            LOG_WARNING("warning: failed to lookup `get_native_lib` function "
-                        "from native lib %p",
-                        handle);
-            continue;
-        }
-
-        n_native_symbols = get_native_lib(&module_name, &native_symbols);
-        if (n_native_symbols == 0 || module_name == NULL
-            || native_symbols == NULL) {
-            LOG_WARNING("warning: get_native_lib returned different values for "
-                        "native lib %p",
-                        handle);
-            continue;
-        }
+        struct native_lib *lib = native_lib_loaded_list[i];
 
         /* unregister native symbols */
-        if (!wasm_runtime_unregister_natives(module_name, native_symbols)) {
-            LOG_WARNING("warning: failed to unregister native lib %p", handle);
+        if (!wasm_runtime_unregister_natives(lib->module_name,
+                                             lib->native_symbols)) {
+            LOG_WARNING("warning: failed to unregister native lib %p",
+                        lib->handle);
             continue;
         }
 
-        deinit_native_lib_func deinit_native_lib =
-            dlsym(handle, "deinit_native_lib");
-        if (deinit_native_lib) {
-            deinit_native_lib();
+        if (lib->deinit_native_lib) {
+            lib->deinit_native_lib();
         }
 
-        dlclose(handle);
+        dlclose(lib->handle);
+        wasm_runtime_free(lib);
     }
 }
 #endif /* BH_HAS_DLFCN */
@@ -525,8 +536,8 @@ main(int argc, char *argv[])
 #if BH_HAS_DLFCN
     const char *native_lib_list[8] = { NULL };
     uint32 native_lib_count = 0;
-    void *native_handle_list[8] = { NULL };
-    uint32 native_handle_count = 0;
+    struct native_lib *native_lib_loaded_list[8];
+    uint32 native_lib_loaded_count = 0;
 #endif
 #if WASM_ENABLE_DEBUG_INTERP != 0
     char *ip_addr = NULL;
@@ -810,8 +821,8 @@ main(int argc, char *argv[])
 #endif
 
 #if BH_HAS_DLFCN
-    native_handle_count = load_and_register_native_libs(
-        native_lib_list, native_lib_count, native_handle_list);
+    native_lib_loaded_count = load_and_register_native_libs(
+        native_lib_list, native_lib_count, native_lib_loaded_list);
 #endif
 
     /* load WASM byte buffer from WASM bin file */
@@ -945,7 +956,8 @@ fail2:
 fail1:
 #if BH_HAS_DLFCN
     /* unload the native libraries */
-    unregister_and_unload_native_libs(native_handle_count, native_handle_list);
+    unregister_and_unload_native_libs(native_lib_loaded_count,
+                                      native_lib_loaded_list);
 #endif
 
     /* destroy runtime environment */