aot_orc_extra.cpp 11 KB


  1. /*
  2. * Copyright (C) 2019 Intel Corporation. All rights reserved.
  3. * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. */
  5. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  6. // See https://llvm.org/LICENSE.txt for license information.
  7. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  8. #include "llvm-c/LLJIT.h"
  9. #include "llvm-c/Orc.h"
  10. #include "llvm-c/OrcEE.h"
  11. #include "llvm-c/TargetMachine.h"
  12. #if LLVM_VERSION_MAJOR < 17
  13. #include "llvm/ADT/None.h"
  14. #include "llvm/ADT/Optional.h"
  15. #endif
  16. #include "llvm/ExecutionEngine/JITEventListener.h"
  17. #include "llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h"
  18. #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
  19. #include "llvm/ExecutionEngine/Orc/LLJIT.h"
  20. #include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
  21. #include "llvm/ExecutionEngine/Orc/ObjectTransformLayer.h"
  22. #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
  23. #include "llvm/ExecutionEngine/SectionMemoryManager.h"
  24. #include "llvm/Support/CBindingWrapping.h"
  25. #include "aot_orc_extra.h"
  26. #include "aot.h"
  27. #if LLVM_VERSION_MAJOR >= 17
  28. namespace llvm {
  29. template<typename T>
  30. using Optional = std::optional<T>;
  31. }
  32. #endif
  33. using namespace llvm;
  34. using namespace llvm::orc;
  35. using GlobalValueSet = std::set<const GlobalValue *>;
  36. namespace llvm {
  37. namespace orc {
  38. class InProgressLookupState;
  39. class OrcV2CAPIHelper
  40. {
  41. public:
  42. #if LLVM_VERSION_MAJOR < 18
  43. using PoolEntry = SymbolStringPtr::PoolEntry;
  44. using PoolEntryPtr = SymbolStringPtr::PoolEntryPtr;
  45. // Move from SymbolStringPtr to PoolEntryPtr (no change in ref count).
  46. static PoolEntryPtr moveFromSymbolStringPtr(SymbolStringPtr S)
  47. {
  48. PoolEntryPtr Result = nullptr;
  49. std::swap(Result, S.S);
  50. return Result;
  51. }
  52. // Move from a PoolEntryPtr to a SymbolStringPtr (no change in ref count).
  53. static SymbolStringPtr moveToSymbolStringPtr(PoolEntryPtr P)
  54. {
  55. SymbolStringPtr S;
  56. S.S = P;
  57. return S;
  58. }
  59. // Copy a pool entry to a SymbolStringPtr (increments ref count).
  60. static SymbolStringPtr copyToSymbolStringPtr(PoolEntryPtr P)
  61. {
  62. return SymbolStringPtr(P);
  63. }
  64. static PoolEntryPtr getRawPoolEntryPtr(const SymbolStringPtr &S)
  65. {
  66. return S.S;
  67. }
  68. static void retainPoolEntry(PoolEntryPtr P)
  69. {
  70. SymbolStringPtr S(P);
  71. S.S = nullptr;
  72. }
  73. static void releasePoolEntry(PoolEntryPtr P)
  74. {
  75. SymbolStringPtr S;
  76. S.S = P;
  77. }
  78. #endif
  79. static InProgressLookupState *extractLookupState(LookupState &LS)
  80. {
  81. return LS.IPLS.release();
  82. }
  83. static void resetLookupState(LookupState &LS, InProgressLookupState *IPLS)
  84. {
  85. return LS.reset(IPLS);
  86. }
  87. };
  88. } // namespace orc
  89. } // namespace llvm
  /* Conversions between llvm-c/Orc.h handles and C++ objects */
  #if LLVM_VERSION_MAJOR >= 18
  /* Expose the raw pool entry carried by E as an opaque C handle. */
  inline LLVMOrcSymbolStringPoolEntryRef
  wrap(SymbolStringPoolEntryUnsafe E)
  {
      auto *RawEntry = E.rawPtr();
      return reinterpret_cast<LLVMOrcSymbolStringPoolEntryRef>(RawEntry);
  }

  /* Reinterpret an opaque C handle as a SymbolStringPoolEntryUnsafe. */
  inline SymbolStringPoolEntryUnsafe
  unwrap(LLVMOrcSymbolStringPoolEntryRef E)
  {
      auto *RawEntry =
          reinterpret_cast<SymbolStringPoolEntryUnsafe::PoolEntry *>(E);
      return SymbolStringPoolEntryUnsafe(RawEntry);
  }
  #endif
  103. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ExecutionSession, LLVMOrcExecutionSessionRef)
  104. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(IRTransformLayer, LLVMOrcIRTransformLayerRef)
  105. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(JITDylib, LLVMOrcJITDylibRef)
  106. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(JITTargetMachineBuilder,
  107. LLVMOrcJITTargetMachineBuilderRef)
  108. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ObjectTransformLayer,
  109. LLVMOrcObjectTransformLayerRef)
  110. #if LLVM_VERSION_MAJOR < 18
  111. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(OrcV2CAPIHelper::PoolEntry,
  112. LLVMOrcSymbolStringPoolEntryRef)
  113. #endif
  114. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ObjectLayer, LLVMOrcObjectLayerRef)
  115. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(SymbolStringPool, LLVMOrcSymbolStringPoolRef)
  116. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ThreadSafeModule, LLVMOrcThreadSafeModuleRef)
  117. // LLJIT.h
  118. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(LLJITBuilder, LLVMOrcLLJITBuilderRef)
  119. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(LLLazyJITBuilder, LLVMOrcLLLazyJITBuilderRef)
  120. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(LLLazyJIT, LLVMOrcLLLazyJITRef)
  121. void
  122. LLVMOrcLLJITBuilderSetNumCompileThreads(LLVMOrcLLJITBuilderRef Builder,
  123. unsigned NumCompileThreads)
  124. {
  125. unwrap(Builder)->setNumCompileThreads(NumCompileThreads);
  126. }
  127. LLVMOrcLLLazyJITBuilderRef
  128. LLVMOrcCreateLLLazyJITBuilder(void)
  129. {
  130. return wrap(new LLLazyJITBuilder());
  131. }
  132. void
  133. LLVMOrcDisposeLLLazyJITBuilder(LLVMOrcLLLazyJITBuilderRef Builder)
  134. {
  135. delete unwrap(Builder);
  136. }
  137. void
  138. LLVMOrcLLLazyJITBuilderSetNumCompileThreads(LLVMOrcLLLazyJITBuilderRef Builder,
  139. unsigned NumCompileThreads)
  140. {
  141. unwrap(Builder)->setNumCompileThreads(NumCompileThreads);
  142. }
  143. void
  144. LLVMOrcLLLazyJITBuilderSetJITTargetMachineBuilder(
  145. LLVMOrcLLLazyJITBuilderRef Builder, LLVMOrcJITTargetMachineBuilderRef JTMP)
  146. {
  147. unwrap(Builder)->setJITTargetMachineBuilder(*unwrap(JTMP));
  148. /* Destroy the JTMP, similar to
  149. LLVMOrcLLJITBuilderSetJITTargetMachineBuilder */
  150. LLVMOrcDisposeJITTargetMachineBuilder(JTMP);
  151. }
  152. static Optional<GlobalValueSet>
  153. PartitionFunction(GlobalValueSet Requested)
  154. {
  155. std::vector<const GlobalValue *> GVsToAdd;
  156. for (auto *GV : Requested) {
  157. if (isa<Function>(GV) && GV->hasName()) {
  158. auto &F = cast<Function>(*GV); /* get LLVM function */
  159. const Module *M = F.getParent(); /* get LLVM module */
  160. auto GVName = GV->getName(); /* get the function name */
  161. const char *gvname = GVName.begin(); /* C function name */
  162. const char *wrapper;
  163. uint32 prefix_len = (uint32)strlen(AOT_FUNC_PREFIX);
  164. LOG_DEBUG("requested func %s", gvname);
  165. /* Convert "aot_func#n_wrapper" to "aot_func#n" */
  166. if (strstr(gvname, AOT_FUNC_PREFIX)) {
  167. char buf[16] = { 0 };
  168. char func_name[64];
  169. int group_stride, i, j;
  170. int num;
  171. /*
  172. * if the jit wrapper (which has "_wrapper" suffix in
  173. * the name) is requested, compile others in the group too.
  174. * otherwise, only compile the requested one.
  175. * (and possibly the corresponding wrapped function,
  176. * which has AOT_FUNC_INTERNAL_PREFIX.)
  177. */
  178. wrapper = strstr(gvname + prefix_len, "_wrapper");
  179. if (wrapper != NULL) {
  180. num = WASM_ORC_JIT_COMPILE_THREAD_NUM;
  181. }
  182. else {
  183. num = 1;
  184. wrapper = strchr(gvname + prefix_len, 0);
  185. }
  186. bh_assert(wrapper - (gvname + prefix_len) > 0);
  187. /* Get AOT function index */
  188. bh_memcpy_s(buf, (uint32)sizeof(buf), gvname + prefix_len,
  189. (uint32)(wrapper - (gvname + prefix_len)));
  190. i = atoi(buf);
  191. group_stride = WASM_ORC_JIT_BACKEND_THREAD_NUM;
  192. /* Compile some functions each time */
  193. for (j = 0; j < num; j++) {
  194. Function *F1;
  195. snprintf(func_name, sizeof(func_name), "%s%d",
  196. AOT_FUNC_PREFIX, i + j * group_stride);
  197. F1 = M->getFunction(func_name);
  198. if (F1) {
  199. LOG_DEBUG("compile func %s", func_name);
  200. GVsToAdd.push_back(cast<GlobalValue>(F1));
  201. }
  202. snprintf(func_name, sizeof(func_name), "%s%d",
  203. AOT_FUNC_INTERNAL_PREFIX, i + j * group_stride);
  204. F1 = M->getFunction(func_name);
  205. if (F1) {
  206. LOG_DEBUG("compile func %s", func_name);
  207. GVsToAdd.push_back(cast<GlobalValue>(F1));
  208. }
  209. }
  210. }
  211. }
  212. }
  213. for (auto *GV : GVsToAdd) {
  214. Requested.insert(GV);
  215. }
  216. return Requested;
  217. }
  218. LLVMErrorRef
  219. LLVMOrcCreateLLLazyJIT(LLVMOrcLLLazyJITRef *Result,
  220. LLVMOrcLLLazyJITBuilderRef Builder)
  221. {
  222. assert(Result && "Result can not be null");
  223. if (!Builder)
  224. Builder = LLVMOrcCreateLLLazyJITBuilder();
  225. auto J = unwrap(Builder)->create();
  226. LLVMOrcDisposeLLLazyJITBuilder(Builder);
  227. if (!J) {
  228. Result = nullptr;
  229. return 0;
  230. }
  231. LLLazyJIT *lazy_jit = J->release();
  232. lazy_jit->setPartitionFunction(PartitionFunction);
  233. *Result = wrap(lazy_jit);
  234. return LLVMErrorSuccess;
  235. }
  236. LLVMErrorRef
  237. LLVMOrcDisposeLLLazyJIT(LLVMOrcLLLazyJITRef J)
  238. {
  239. delete unwrap(J);
  240. return LLVMErrorSuccess;
  241. }
  242. LLVMErrorRef
  243. LLVMOrcLLLazyJITAddLLVMIRModule(LLVMOrcLLLazyJITRef J, LLVMOrcJITDylibRef JD,
  244. LLVMOrcThreadSafeModuleRef TSM)
  245. {
  246. std::unique_ptr<ThreadSafeModule> TmpTSM(unwrap(TSM));
  247. return wrap(unwrap(J)->addLazyIRModule(*unwrap(JD), std::move(*TmpTSM)));
  248. }
  249. LLVMErrorRef
  250. LLVMOrcLLLazyJITLookup(LLVMOrcLLLazyJITRef J, LLVMOrcExecutorAddress *Result,
  251. const char *Name)
  252. {
  253. assert(Result && "Result can not be null");
  254. auto Sym = unwrap(J)->lookup(Name);
  255. if (!Sym) {
  256. *Result = 0;
  257. return wrap(Sym.takeError());
  258. }
  259. #if LLVM_VERSION_MAJOR < 15
  260. *Result = Sym->getAddress();
  261. #else
  262. *Result = Sym->getValue();
  263. #endif
  264. return LLVMErrorSuccess;
  265. }
  266. LLVMOrcSymbolStringPoolEntryRef
  267. LLVMOrcLLLazyJITMangleAndIntern(LLVMOrcLLLazyJITRef J,
  268. const char *UnmangledName)
  269. {
  270. #if LLVM_VERSION_MAJOR < 18
  271. return wrap(OrcV2CAPIHelper::moveFromSymbolStringPtr(
  272. unwrap(J)->mangleAndIntern(UnmangledName)));
  273. #else
  274. return wrap(SymbolStringPoolEntryUnsafe::take(
  275. unwrap(J)->mangleAndIntern(UnmangledName)));
  276. #endif
  277. }
  278. LLVMOrcJITDylibRef
  279. LLVMOrcLLLazyJITGetMainJITDylib(LLVMOrcLLLazyJITRef J)
  280. {
  281. return wrap(&unwrap(J)->getMainJITDylib());
  282. }
  283. const char *
  284. LLVMOrcLLLazyJITGetTripleString(LLVMOrcLLLazyJITRef J)
  285. {
  286. return unwrap(J)->getTargetTriple().str().c_str();
  287. }
  288. LLVMOrcExecutionSessionRef
  289. LLVMOrcLLLazyJITGetExecutionSession(LLVMOrcLLLazyJITRef J)
  290. {
  291. return wrap(&unwrap(J)->getExecutionSession());
  292. }
  293. LLVMOrcIRTransformLayerRef
  294. LLVMOrcLLLazyJITGetIRTransformLayer(LLVMOrcLLLazyJITRef J)
  295. {
  296. return wrap(&unwrap(J)->getIRTransformLayer());
  297. }
  298. LLVMOrcObjectTransformLayerRef
  299. LLVMOrcLLLazyJITGetObjTransformLayer(LLVMOrcLLLazyJITRef J)
  300. {
  301. return wrap(&unwrap(J)->getObjTransformLayer());
  302. }
  303. LLVMOrcObjectLayerRef
  304. LLVMOrcLLLazyJITGetObjLinkingLayer(LLVMOrcLLLazyJITRef J)
  305. {
  306. return wrap(&unwrap(J)->getObjLinkingLayer());
  307. }