aot_orc_extra.cpp 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. /*
  2. * Copyright (C) 2019 Intel Corporation. All rights reserved.
  3. * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. */
  5. #include "llvm-c/LLJIT.h"
  6. #include "llvm-c/Orc.h"
  7. #include "llvm-c/OrcEE.h"
  8. #include "llvm-c/TargetMachine.h"
  9. #if LLVM_VERSION_MAJOR < 17
  10. #include "llvm/ADT/None.h"
  11. #include "llvm/ADT/Optional.h"
  12. #endif
  13. #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
  14. #include "llvm/ExecutionEngine/Orc/LLJIT.h"
  15. #include "llvm/ExecutionEngine/Orc/ObjectTransformLayer.h"
  16. #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
  17. #include "llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h"
  18. #include "llvm/ExecutionEngine/SectionMemoryManager.h"
  19. #include "llvm/Support/CBindingWrapping.h"
  20. #include "aot_orc_extra.h"
  21. #include "aot.h"
  22. #if LLVM_VERSION_MAJOR >= 17
  23. namespace llvm {
  24. template<typename T>
  25. using Optional = std::optional<T>;
  26. }
  27. #endif
  28. using namespace llvm;
  29. using namespace llvm::orc;
  30. using GlobalValueSet = std::set<const GlobalValue *>;
  31. namespace llvm {
  32. namespace orc {
  33. class InProgressLookupState;
  34. class OrcV2CAPIHelper
  35. {
  36. public:
  37. using PoolEntry = SymbolStringPtr::PoolEntry;
  38. using PoolEntryPtr = SymbolStringPtr::PoolEntryPtr;
  39. // Move from SymbolStringPtr to PoolEntryPtr (no change in ref count).
  40. static PoolEntryPtr moveFromSymbolStringPtr(SymbolStringPtr S)
  41. {
  42. PoolEntryPtr Result = nullptr;
  43. std::swap(Result, S.S);
  44. return Result;
  45. }
  46. // Move from a PoolEntryPtr to a SymbolStringPtr (no change in ref count).
  47. static SymbolStringPtr moveToSymbolStringPtr(PoolEntryPtr P)
  48. {
  49. SymbolStringPtr S;
  50. S.S = P;
  51. return S;
  52. }
  53. // Copy a pool entry to a SymbolStringPtr (increments ref count).
  54. static SymbolStringPtr copyToSymbolStringPtr(PoolEntryPtr P)
  55. {
  56. return SymbolStringPtr(P);
  57. }
  58. static PoolEntryPtr getRawPoolEntryPtr(const SymbolStringPtr &S)
  59. {
  60. return S.S;
  61. }
  62. static void retainPoolEntry(PoolEntryPtr P)
  63. {
  64. SymbolStringPtr S(P);
  65. S.S = nullptr;
  66. }
  67. static void releasePoolEntry(PoolEntryPtr P)
  68. {
  69. SymbolStringPtr S;
  70. S.S = P;
  71. }
  72. static InProgressLookupState *extractLookupState(LookupState &LS)
  73. {
  74. return LS.IPLS.release();
  75. }
  76. static void resetLookupState(LookupState &LS, InProgressLookupState *IPLS)
  77. {
  78. return LS.reset(IPLS);
  79. }
  80. };
  81. } // namespace orc
  82. } // namespace llvm
  83. // ORC.h
  84. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ExecutionSession, LLVMOrcExecutionSessionRef)
  85. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(IRTransformLayer, LLVMOrcIRTransformLayerRef)
  86. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(JITDylib, LLVMOrcJITDylibRef)
  87. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(JITTargetMachineBuilder,
  88. LLVMOrcJITTargetMachineBuilderRef)
  89. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ObjectTransformLayer,
  90. LLVMOrcObjectTransformLayerRef)
  91. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(OrcV2CAPIHelper::PoolEntry,
  92. LLVMOrcSymbolStringPoolEntryRef)
  93. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(SymbolStringPool, LLVMOrcSymbolStringPoolRef)
  94. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ThreadSafeModule, LLVMOrcThreadSafeModuleRef)
  95. // LLJIT.h
  96. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(LLJITBuilder, LLVMOrcLLJITBuilderRef)
  97. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(LLLazyJITBuilder, LLVMOrcLLLazyJITBuilderRef)
  98. DEFINE_SIMPLE_CONVERSION_FUNCTIONS(LLLazyJIT, LLVMOrcLLLazyJITRef)
  99. void
  100. LLVMOrcLLJITBuilderSetNumCompileThreads(LLVMOrcLLJITBuilderRef Builder,
  101. unsigned NumCompileThreads)
  102. {
  103. unwrap(Builder)->setNumCompileThreads(NumCompileThreads);
  104. }
  105. LLVMOrcLLLazyJITBuilderRef
  106. LLVMOrcCreateLLLazyJITBuilder(void)
  107. {
  108. return wrap(new LLLazyJITBuilder());
  109. }
  110. void
  111. LLVMOrcDisposeLLLazyJITBuilder(LLVMOrcLLLazyJITBuilderRef Builder)
  112. {
  113. delete unwrap(Builder);
  114. }
  115. void
  116. LLVMOrcLLLazyJITBuilderSetNumCompileThreads(LLVMOrcLLLazyJITBuilderRef Builder,
  117. unsigned NumCompileThreads)
  118. {
  119. unwrap(Builder)->setNumCompileThreads(NumCompileThreads);
  120. }
  121. void
  122. LLVMOrcLLLazyJITBuilderSetJITTargetMachineBuilder(
  123. LLVMOrcLLLazyJITBuilderRef Builder, LLVMOrcJITTargetMachineBuilderRef JTMP)
  124. {
  125. unwrap(Builder)->setJITTargetMachineBuilder(*unwrap(JTMP));
  126. /* Destroy the JTMP, similar to
  127. LLVMOrcLLJITBuilderSetJITTargetMachineBuilder */
  128. LLVMOrcDisposeJITTargetMachineBuilder(JTMP);
  129. }
  130. static Optional<CompileOnDemandLayer::GlobalValueSet>
  131. PartitionFunction(GlobalValueSet Requested)
  132. {
  133. std::vector<const GlobalValue *> GVsToAdd;
  134. for (auto *GV : Requested) {
  135. if (isa<Function>(GV) && GV->hasName()) {
  136. auto &F = cast<Function>(*GV); /* get LLVM function */
  137. const Module *M = F.getParent(); /* get LLVM module */
  138. auto GVName = GV->getName(); /* get the function name */
  139. const char *gvname = GVName.begin(); /* C function name */
  140. const char *wrapper;
  141. uint32 prefix_len = strlen(AOT_FUNC_PREFIX);
  142. LOG_DEBUG("requested func %s", gvname);
  143. /* Convert "aot_func#n_wrapper" to "aot_func#n" */
  144. if (strstr(gvname, AOT_FUNC_PREFIX)) {
  145. char buf[16] = { 0 };
  146. char func_name[64];
  147. int group_stride, i, j;
  148. int num;
  149. /*
  150. * if the jit wrapper (which has "_wrapper" suffix in
  151. * the name) is requested, compile others in the group too.
  152. * otherwise, only compile the requested one.
  153. * (and possibly the correspondig wrapped function,
  154. * which has AOT_FUNC_INTERNAL_PREFIX.)
  155. */
  156. wrapper = strstr(gvname + prefix_len, "_wrapper");
  157. if (wrapper != NULL) {
  158. num = WASM_ORC_JIT_COMPILE_THREAD_NUM;
  159. }
  160. else {
  161. num = 1;
  162. wrapper = strchr(gvname + prefix_len, 0);
  163. }
  164. bh_assert(wrapper - (gvname + prefix_len) > 0);
  165. /* Get AOT function index */
  166. bh_memcpy_s(buf, (uint32)sizeof(buf), gvname + prefix_len,
  167. (uint32)(wrapper - (gvname + prefix_len)));
  168. i = atoi(buf);
  169. group_stride = WASM_ORC_JIT_BACKEND_THREAD_NUM;
  170. /* Compile some functions each time */
  171. for (j = 0; j < num; j++) {
  172. Function *F1;
  173. snprintf(func_name, sizeof(func_name), "%s%d",
  174. AOT_FUNC_PREFIX, i + j * group_stride);
  175. F1 = M->getFunction(func_name);
  176. if (F1) {
  177. LOG_DEBUG("compile func %s", func_name);
  178. GVsToAdd.push_back(cast<GlobalValue>(F1));
  179. }
  180. snprintf(func_name, sizeof(func_name), "%s%d",
  181. AOT_FUNC_INTERNAL_PREFIX, i + j * group_stride);
  182. F1 = M->getFunction(func_name);
  183. if (F1) {
  184. LOG_DEBUG("compile func %s", func_name);
  185. GVsToAdd.push_back(cast<GlobalValue>(F1));
  186. }
  187. }
  188. }
  189. }
  190. }
  191. for (auto *GV : GVsToAdd) {
  192. Requested.insert(GV);
  193. }
  194. return Requested;
  195. }
  196. LLVMErrorRef
  197. LLVMOrcCreateLLLazyJIT(LLVMOrcLLLazyJITRef *Result,
  198. LLVMOrcLLLazyJITBuilderRef Builder)
  199. {
  200. assert(Result && "Result can not be null");
  201. if (!Builder)
  202. Builder = LLVMOrcCreateLLLazyJITBuilder();
  203. auto J = unwrap(Builder)->create();
  204. LLVMOrcDisposeLLLazyJITBuilder(Builder);
  205. if (!J) {
  206. Result = nullptr;
  207. return 0;
  208. }
  209. LLLazyJIT *lazy_jit = J->release();
  210. lazy_jit->setPartitionFunction(PartitionFunction);
  211. *Result = wrap(lazy_jit);
  212. return LLVMErrorSuccess;
  213. }
  214. LLVMErrorRef
  215. LLVMOrcDisposeLLLazyJIT(LLVMOrcLLLazyJITRef J)
  216. {
  217. delete unwrap(J);
  218. return LLVMErrorSuccess;
  219. }
  220. LLVMErrorRef
  221. LLVMOrcLLLazyJITAddLLVMIRModule(LLVMOrcLLLazyJITRef J, LLVMOrcJITDylibRef JD,
  222. LLVMOrcThreadSafeModuleRef TSM)
  223. {
  224. std::unique_ptr<ThreadSafeModule> TmpTSM(unwrap(TSM));
  225. return wrap(unwrap(J)->addLazyIRModule(*unwrap(JD), std::move(*TmpTSM)));
  226. }
  227. LLVMErrorRef
  228. LLVMOrcLLLazyJITLookup(LLVMOrcLLLazyJITRef J, LLVMOrcExecutorAddress *Result,
  229. const char *Name)
  230. {
  231. assert(Result && "Result can not be null");
  232. auto Sym = unwrap(J)->lookup(Name);
  233. if (!Sym) {
  234. *Result = 0;
  235. return wrap(Sym.takeError());
  236. }
  237. #if LLVM_VERSION_MAJOR < 15
  238. *Result = Sym->getAddress();
  239. #else
  240. *Result = Sym->getValue();
  241. #endif
  242. return LLVMErrorSuccess;
  243. }
  244. LLVMOrcSymbolStringPoolEntryRef
  245. LLVMOrcLLLazyJITMangleAndIntern(LLVMOrcLLLazyJITRef J,
  246. const char *UnmangledName)
  247. {
  248. return wrap(OrcV2CAPIHelper::moveFromSymbolStringPtr(
  249. unwrap(J)->mangleAndIntern(UnmangledName)));
  250. }
  251. LLVMOrcJITDylibRef
  252. LLVMOrcLLLazyJITGetMainJITDylib(LLVMOrcLLLazyJITRef J)
  253. {
  254. return wrap(&unwrap(J)->getMainJITDylib());
  255. }
  256. const char *
  257. LLVMOrcLLLazyJITGetTripleString(LLVMOrcLLLazyJITRef J)
  258. {
  259. return unwrap(J)->getTargetTriple().str().c_str();
  260. }
  261. LLVMOrcExecutionSessionRef
  262. LLVMOrcLLLazyJITGetExecutionSession(LLVMOrcLLLazyJITRef J)
  263. {
  264. return wrap(&unwrap(J)->getExecutionSession());
  265. }
  266. LLVMOrcIRTransformLayerRef
  267. LLVMOrcLLLazyJITGetIRTransformLayer(LLVMOrcLLLazyJITRef J)
  268. {
  269. return wrap(&unwrap(J)->getIRTransformLayer());
  270. }
  271. LLVMOrcObjectTransformLayerRef
  272. LLVMOrcLLLazyJITGetObjTransformLayer(LLVMOrcLLLazyJITRef J)
  273. {
  274. return wrap(&unwrap(J)->getObjTransformLayer());
  275. }