nncase.cpp 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. /* Copyright 2019 Canaan Inc.
  2. *
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. *
  7. * http://www.apache.org/licenses/LICENSE-2.0
  8. *
  9. * Unless required by applicable law or agreed to in writing, software
  10. * distributed under the License is distributed on an "AS IS" BASIS,
  11. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. * See the License for the specific language governing permissions and
  13. * limitations under the License.
  14. */
  15. #include <nncase.h>
  16. #include <kernels/k210/k210_kernels.h>
  17. #include <runtime/target_interpreter.h>
  18. #include <stdio.h>
  19. #include <cstring>
  20. #include <utils.h>
  21. using namespace nncase;
  22. using namespace nncase::runtime;
  23. #define NNCASE_DEBUG 0
  24. namespace
  25. {
  26. void kpu_upload_dma(dmac_channel_number_t dma_ch, const uint8_t *src, uint8_t *dest, size_t input_size, plic_irq_callback_t callback, void *userdata)
  27. {
  28. if (is_memory_cache((uintptr_t)src))
  29. {
  30. std::copy_n(src, input_size, dest);
  31. src -= 0x40000000;
  32. }
  33. dmac_set_irq(dma_ch, callback, userdata, 1);
  34. dmac_set_single_mode(dma_ch, (void *)src, (void *)dest, DMAC_ADDR_INCREMENT, DMAC_ADDR_INCREMENT,
  35. DMAC_MSIZE_16, DMAC_TRANS_WIDTH_64, input_size / 8);
  36. dmac_wait_done(dma_ch);
  37. }
  38. }
  39. class nncase_context
  40. {
  41. public:
  42. int load_kmodel(const uint8_t *buffer)
  43. {
  44. int ret = interpreter_.try_load_model(buffer) ? 0 : -1;
  45. uint32_t size = interpreter_.model_size(buffer);
  46. uint8_t *buffer_iomem = (uint8_t *)((uintptr_t)buffer - IOMEM);
  47. const uint8_t *buffer_cache = buffer;
  48. memcpy(buffer_iomem, buffer_cache, size);
  49. for (int i = 0; i < size; i++)
  50. {
  51. if (buffer_iomem[i] != buffer_cache[i])
  52. {
  53. printf("flush model fail:%d %x %x \n", i, buffer_iomem[i], buffer_cache[i]);
  54. while (1)
  55. ;
  56. }
  57. }
  58. return ret;
  59. }
  60. int get_output(uint32_t index, uint8_t **data, size_t *size)
  61. {
  62. if (index >= interpreter_.outputs_size())
  63. return -1;
  64. auto mem = interpreter_.memory_at<uint8_t>(interpreter_.output_at(index));
  65. *data = mem.data();
  66. *size = mem.size();
  67. return 0;
  68. }
  69. int run_kmodel(const uint8_t *src, dmac_channel_number_t dma_ch, kpu_done_callback_t done_callback, void *userdata)
  70. {
  71. done_callback_ = done_callback;
  72. userdata_ = userdata;
  73. interpreter_.dma_ch(dma_ch);
  74. auto input = interpreter_.input_at(0);
  75. auto mem = interpreter_.memory_at<uint8_t>(input);
  76. if (input.memory_type == mem_main)
  77. {
  78. std::copy(src, src + mem.size(), mem.begin());
  79. interpreter_.run(done_thunk, on_error_thunk, node_profile_thunk, this);
  80. return 0;
  81. }
  82. else if (input.memory_type == mem_k210_kpu)
  83. {
  84. auto shape = interpreter_.input_shape_at(0);
  85. kernels::k210::kpu_upload(src, mem.data(), shape);
  86. on_upload_done();
  87. return 0;
  88. }
  89. return -1;
  90. }
  91. private:
  92. void on_done()
  93. {
  94. #if NNCASE_DEBUG
  95. printf("Total: %fms\n", interpreter_.total_duration().count() / 1e6);
  96. #endif
  97. if (done_callback_)
  98. done_callback_(userdata_);
  99. }
  100. void on_upload_done()
  101. {
  102. interpreter_.run(done_thunk, on_error_thunk, node_profile_thunk, this);
  103. }
  104. static void done_thunk(void *userdata)
  105. {
  106. reinterpret_cast<nncase_context *>(userdata)->on_done();
  107. }
  108. static void on_error_thunk(const char *err, void *userdata)
  109. {
  110. #if NNCASE_DEBUG
  111. printf("Fatal: %s\n", err);
  112. #endif
  113. }
  114. static void node_profile_thunk(runtime_opcode op, std::chrono::nanoseconds duration, void *userdata)
  115. {
  116. #if NNCASE_DEBUG
  117. printf("%s: %fms\n", node_opcode_names(op).data(), duration.count() / 1e6);
  118. #endif
  119. }
  120. static int upload_done_thunk(void *userdata)
  121. {
  122. reinterpret_cast<nncase_context *>(userdata)->on_upload_done();
  123. return 0;
  124. }
  125. private:
  126. interpreter_t interpreter_;
  127. kpu_done_callback_t done_callback_;
  128. void *userdata_;
  129. };
  130. int nncase_load_kmodel(kpu_model_context_t *ctx, const uint8_t *buffer)
  131. {
  132. auto nnctx = new (std::nothrow) nncase_context();
  133. if (ctx)
  134. {
  135. ctx->is_nncase = 1;
  136. ctx->nncase_ctx = nnctx;
  137. return nnctx->load_kmodel(buffer);
  138. }
  139. else
  140. {
  141. return -1;
  142. }
  143. }
  144. int nncase_get_output(kpu_model_context_t *ctx, uint32_t index, uint8_t **data, size_t *size)
  145. {
  146. auto nnctx = reinterpret_cast<nncase_context *>(ctx->nncase_ctx);
  147. return nnctx->get_output(index, data, size);
  148. }
  149. void nncase_model_free(kpu_model_context_t *ctx)
  150. {
  151. auto nnctx = reinterpret_cast<nncase_context *>(ctx->nncase_ctx);
  152. delete nnctx;
  153. ctx->nncase_ctx = nullptr;
  154. }
  155. int nncase_run_kmodel(kpu_model_context_t *ctx, const uint8_t *src, dmac_channel_number_t dma_ch, kpu_done_callback_t done_callback, void *userdata)
  156. {
  157. auto nnctx = reinterpret_cast<nncase_context *>(ctx->nncase_ctx);
  158. return nnctx->run_kmodel(src, dma_ch, done_callback, userdata);
  159. }