aot_orc_extra.cpp 9.9 KB


  1. /*
  2. * Copyright (C) 2019 Intel Corporation. All rights reserved.
  3. * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. */
  5. #include "llvm-c/LLJIT.h"
  6. #include "llvm-c/Orc.h"
  7. #include "llvm-c/OrcEE.h"
  8. #include "llvm-c/TargetMachine.h"
  9. #if LLVM_VERSION_MAJOR < 17
  10. #include "llvm/ADT/None.h"
  11. #include "llvm/ADT/Optional.h"
  12. #endif
  13. #include "llvm/ExecutionEngine/JITEventListener.h"
  14. #include "llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h"
  15. #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
  16. #include "llvm/ExecutionEngine/Orc/LLJIT.h"
  17. #include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
  18. #include "llvm/ExecutionEngine/Orc/ObjectTransformLayer.h"
  19. #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
  20. #include "llvm/ExecutionEngine/SectionMemoryManager.h"
  21. #include "llvm/Support/CBindingWrapping.h"
  22. #include "aot_orc_extra.h"
  23. #include "aot.h"
#if LLVM_VERSION_MAJOR >= 17
/*
 * Compatibility shim: llvm/ADT/Optional.h is only included for
 * LLVM < 17 (see the version guard on the includes above). For newer
 * LLVM, alias llvm::Optional to std::optional so the rest of this file
 * can keep spelling the type as Optional<...>.
 */
namespace llvm {
template<typename T>
using Optional = std::optional<T>;
}
#endif
using namespace llvm;
using namespace llvm::orc;
/* Parameter type of PartitionFunction below. */
using GlobalValueSet = std::set<const GlobalValue *>;
namespace llvm {
namespace orc {
class InProgressLookupState;
/*
 * Helper that converts between SymbolStringPtr / LookupState and the raw
 * pointers handed across the C API. It reads and writes private members
 * (SymbolStringPtr::S, LookupState::IPLS) directly.
 * NOTE(review): that access presumably compiles only because LLVM declares
 * a class with exactly this name in llvm::orc as a friend of those types;
 * confirm this still holds when moving to a new LLVM release.
 */
class OrcV2CAPIHelper
{
public:
    using PoolEntry = SymbolStringPtr::PoolEntry;
    using PoolEntryPtr = SymbolStringPtr::PoolEntryPtr;
    // Move from SymbolStringPtr to PoolEntryPtr (no change in ref count).
    static PoolEntryPtr moveFromSymbolStringPtr(SymbolStringPtr S)
    {
        /* Steal the raw pointer and leave S empty, so S's destructor
           does not drop the reference now owned by the caller. */
        PoolEntryPtr Result = nullptr;
        std::swap(Result, S.S);
        return Result;
    }
    // Move from a PoolEntryPtr to a SymbolStringPtr (no change in ref count).
    static SymbolStringPtr moveToSymbolStringPtr(PoolEntryPtr P)
    {
        /* Assign the member directly instead of using the constructor,
           which would take an extra reference. */
        SymbolStringPtr S;
        S.S = P;
        return S;
    }
    // Copy a pool entry to a SymbolStringPtr (increments ref count).
    static SymbolStringPtr copyToSymbolStringPtr(PoolEntryPtr P)
    {
        return SymbolStringPtr(P);
    }
    // Borrow the raw pool entry pointer held by S (no change in ref count).
    static PoolEntryPtr getRawPoolEntryPtr(const SymbolStringPtr &S)
    {
        return S.S;
    }
    // Add one reference to P.
    static void retainPoolEntry(PoolEntryPtr P)
    {
        /* Constructing from P takes a reference (as copyToSymbolStringPtr
           does); clearing S.S before S goes out of scope keeps it. */
        SymbolStringPtr S(P);
        S.S = nullptr;
    }
    // Drop one reference from P.
    static void releasePoolEntry(PoolEntryPtr P)
    {
        /* Adopt P without taking a reference (as moveToSymbolStringPtr
           does); S's destructor then releases one reference. */
        SymbolStringPtr S;
        S.S = P;
    }
    // Take ownership of the in-progress lookup state out of LS.
    static InProgressLookupState *extractLookupState(LookupState &LS)
    {
        return LS.IPLS.release();
    }
    // Hand an in-progress lookup state (back) to LS via LookupState::reset.
    static void resetLookupState(LookupState &LS, InProgressLookupState *IPLS)
    {
        return LS.reset(IPLS);
    }
};
} // namespace orc
} // namespace llvm
/*
 * wrap()/unwrap() converters between the ORC C++ classes and their opaque
 * C-API handle types, generated by DEFINE_SIMPLE_CONVERSION_FUNCTIONS from
 * llvm/Support/CBindingWrapping.h. The "// ORC.h" / "// LLJIT.h" markers
 * presumably name the llvm-c header declaring each handle type; the
 * LLLazyJIT handles are this project's own additions (NOTE(review): they
 * are expected to come from aot_orc_extra.h — confirm).
 */
// ORC.h
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ExecutionSession, LLVMOrcExecutionSessionRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(IRTransformLayer, LLVMOrcIRTransformLayerRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(JITDylib, LLVMOrcJITDylibRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(JITTargetMachineBuilder,
                                   LLVMOrcJITTargetMachineBuilderRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ObjectTransformLayer,
                                   LLVMOrcObjectTransformLayerRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(OrcV2CAPIHelper::PoolEntry,
                                   LLVMOrcSymbolStringPoolEntryRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ObjectLayer, LLVMOrcObjectLayerRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(SymbolStringPool, LLVMOrcSymbolStringPoolRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ThreadSafeModule, LLVMOrcThreadSafeModuleRef)
// LLJIT.h
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(LLJITBuilder, LLVMOrcLLJITBuilderRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(LLLazyJITBuilder, LLVMOrcLLLazyJITBuilderRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(LLLazyJIT, LLVMOrcLLLazyJITRef)
  102. void
  103. LLVMOrcLLJITBuilderSetNumCompileThreads(LLVMOrcLLJITBuilderRef Builder,
  104. unsigned NumCompileThreads)
  105. {
  106. unwrap(Builder)->setNumCompileThreads(NumCompileThreads);
  107. }
  108. LLVMOrcLLLazyJITBuilderRef
  109. LLVMOrcCreateLLLazyJITBuilder(void)
  110. {
  111. return wrap(new LLLazyJITBuilder());
  112. }
  113. void
  114. LLVMOrcDisposeLLLazyJITBuilder(LLVMOrcLLLazyJITBuilderRef Builder)
  115. {
  116. delete unwrap(Builder);
  117. }
  118. void
  119. LLVMOrcLLLazyJITBuilderSetNumCompileThreads(LLVMOrcLLLazyJITBuilderRef Builder,
  120. unsigned NumCompileThreads)
  121. {
  122. unwrap(Builder)->setNumCompileThreads(NumCompileThreads);
  123. }
  124. void
  125. LLVMOrcLLLazyJITBuilderSetJITTargetMachineBuilder(
  126. LLVMOrcLLLazyJITBuilderRef Builder, LLVMOrcJITTargetMachineBuilderRef JTMP)
  127. {
  128. unwrap(Builder)->setJITTargetMachineBuilder(*unwrap(JTMP));
  129. /* Destroy the JTMP, similar to
  130. LLVMOrcLLJITBuilderSetJITTargetMachineBuilder */
  131. LLVMOrcDisposeJITTargetMachineBuilder(JTMP);
  132. }
  133. static Optional<CompileOnDemandLayer::GlobalValueSet>
  134. PartitionFunction(GlobalValueSet Requested)
  135. {
  136. std::vector<const GlobalValue *> GVsToAdd;
  137. for (auto *GV : Requested) {
  138. if (isa<Function>(GV) && GV->hasName()) {
  139. auto &F = cast<Function>(*GV); /* get LLVM function */
  140. const Module *M = F.getParent(); /* get LLVM module */
  141. auto GVName = GV->getName(); /* get the function name */
  142. const char *gvname = GVName.begin(); /* C function name */
  143. const char *wrapper;
  144. uint32 prefix_len = strlen(AOT_FUNC_PREFIX);
  145. LOG_DEBUG("requested func %s", gvname);
  146. /* Convert "aot_func#n_wrapper" to "aot_func#n" */
  147. if (strstr(gvname, AOT_FUNC_PREFIX)) {
  148. char buf[16] = { 0 };
  149. char func_name[64];
  150. int group_stride, i, j;
  151. int num;
  152. /*
  153. * if the jit wrapper (which has "_wrapper" suffix in
  154. * the name) is requested, compile others in the group too.
  155. * otherwise, only compile the requested one.
  156. * (and possibly the correspondig wrapped function,
  157. * which has AOT_FUNC_INTERNAL_PREFIX.)
  158. */
  159. wrapper = strstr(gvname + prefix_len, "_wrapper");
  160. if (wrapper != NULL) {
  161. num = WASM_ORC_JIT_COMPILE_THREAD_NUM;
  162. }
  163. else {
  164. num = 1;
  165. wrapper = strchr(gvname + prefix_len, 0);
  166. }
  167. bh_assert(wrapper - (gvname + prefix_len) > 0);
  168. /* Get AOT function index */
  169. bh_memcpy_s(buf, (uint32)sizeof(buf), gvname + prefix_len,
  170. (uint32)(wrapper - (gvname + prefix_len)));
  171. i = atoi(buf);
  172. group_stride = WASM_ORC_JIT_BACKEND_THREAD_NUM;
  173. /* Compile some functions each time */
  174. for (j = 0; j < num; j++) {
  175. Function *F1;
  176. snprintf(func_name, sizeof(func_name), "%s%d",
  177. AOT_FUNC_PREFIX, i + j * group_stride);
  178. F1 = M->getFunction(func_name);
  179. if (F1) {
  180. LOG_DEBUG("compile func %s", func_name);
  181. GVsToAdd.push_back(cast<GlobalValue>(F1));
  182. }
  183. snprintf(func_name, sizeof(func_name), "%s%d",
  184. AOT_FUNC_INTERNAL_PREFIX, i + j * group_stride);
  185. F1 = M->getFunction(func_name);
  186. if (F1) {
  187. LOG_DEBUG("compile func %s", func_name);
  188. GVsToAdd.push_back(cast<GlobalValue>(F1));
  189. }
  190. }
  191. }
  192. }
  193. }
  194. for (auto *GV : GVsToAdd) {
  195. Requested.insert(GV);
  196. }
  197. return Requested;
  198. }
  199. LLVMErrorRef
  200. LLVMOrcCreateLLLazyJIT(LLVMOrcLLLazyJITRef *Result,
  201. LLVMOrcLLLazyJITBuilderRef Builder)
  202. {
  203. assert(Result && "Result can not be null");
  204. if (!Builder)
  205. Builder = LLVMOrcCreateLLLazyJITBuilder();
  206. auto J = unwrap(Builder)->create();
  207. LLVMOrcDisposeLLLazyJITBuilder(Builder);
  208. if (!J) {
  209. Result = nullptr;
  210. return 0;
  211. }
  212. LLLazyJIT *lazy_jit = J->release();
  213. lazy_jit->setPartitionFunction(PartitionFunction);
  214. *Result = wrap(lazy_jit);
  215. return LLVMErrorSuccess;
  216. }
  217. LLVMErrorRef
  218. LLVMOrcDisposeLLLazyJIT(LLVMOrcLLLazyJITRef J)
  219. {
  220. delete unwrap(J);
  221. return LLVMErrorSuccess;
  222. }
  223. LLVMErrorRef
  224. LLVMOrcLLLazyJITAddLLVMIRModule(LLVMOrcLLLazyJITRef J, LLVMOrcJITDylibRef JD,
  225. LLVMOrcThreadSafeModuleRef TSM)
  226. {
  227. std::unique_ptr<ThreadSafeModule> TmpTSM(unwrap(TSM));
  228. return wrap(unwrap(J)->addLazyIRModule(*unwrap(JD), std::move(*TmpTSM)));
  229. }
  230. LLVMErrorRef
  231. LLVMOrcLLLazyJITLookup(LLVMOrcLLLazyJITRef J, LLVMOrcExecutorAddress *Result,
  232. const char *Name)
  233. {
  234. assert(Result && "Result can not be null");
  235. auto Sym = unwrap(J)->lookup(Name);
  236. if (!Sym) {
  237. *Result = 0;
  238. return wrap(Sym.takeError());
  239. }
  240. #if LLVM_VERSION_MAJOR < 15
  241. *Result = Sym->getAddress();
  242. #else
  243. *Result = Sym->getValue();
  244. #endif
  245. return LLVMErrorSuccess;
  246. }
  247. LLVMOrcSymbolStringPoolEntryRef
  248. LLVMOrcLLLazyJITMangleAndIntern(LLVMOrcLLLazyJITRef J,
  249. const char *UnmangledName)
  250. {
  251. return wrap(OrcV2CAPIHelper::moveFromSymbolStringPtr(
  252. unwrap(J)->mangleAndIntern(UnmangledName)));
  253. }
  254. LLVMOrcJITDylibRef
  255. LLVMOrcLLLazyJITGetMainJITDylib(LLVMOrcLLLazyJITRef J)
  256. {
  257. return wrap(&unwrap(J)->getMainJITDylib());
  258. }
  259. const char *
  260. LLVMOrcLLLazyJITGetTripleString(LLVMOrcLLLazyJITRef J)
  261. {
  262. return unwrap(J)->getTargetTriple().str().c_str();
  263. }
  264. LLVMOrcExecutionSessionRef
  265. LLVMOrcLLLazyJITGetExecutionSession(LLVMOrcLLLazyJITRef J)
  266. {
  267. return wrap(&unwrap(J)->getExecutionSession());
  268. }
  269. LLVMOrcIRTransformLayerRef
  270. LLVMOrcLLLazyJITGetIRTransformLayer(LLVMOrcLLLazyJITRef J)
  271. {
  272. return wrap(&unwrap(J)->getIRTransformLayer());
  273. }
  274. LLVMOrcObjectTransformLayerRef
  275. LLVMOrcLLLazyJITGetObjTransformLayer(LLVMOrcLLLazyJITRef J)
  276. {
  277. return wrap(&unwrap(J)->getObjTransformLayer());
  278. }
  279. LLVMOrcObjectLayerRef
  280. LLVMOrcLLLazyJITGetObjLinkingLayer(LLVMOrcLLLazyJITRef J)
  281. {
  282. return wrap(&unwrap(J)->getObjLinkingLayer());
  283. }