interpreter.cpp
/* Copyright 2019 Canaan Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cassert>
#include <cstdio>
#include <iostream>
#include <runtime/interpreter.h>
#include <runtime/kernel_registry.h>

using namespace nncase;
using namespace nncase::runtime;
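
// try_load_model parses a model serialized as one contiguous buffer. The
// layout, as read back below, is:
//
//   model_header | input ranges | input shapes | output ranges
//   | constants | node headers | node bodies
//
// A minimal caller sketch (hypothetical: `model_buffer` and the error path
// are assumptions, not part of this file):
//
//   interpreter_t interp;
//   if (!interp.try_load_model(model_buffer))
//       return; // bad identifier/version/target, or allocation failure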
bool interpreter_base::try_load_model(const uint8_t *buffer)
{
    auto offset = buffer;
    model_header_ = reinterpret_cast<const model_header *>(buffer);

    // Validate model
    if (model_header_->identifier != MODEL_IDENTIFIER || model_header_->version != MODEL_VERSION
        || (model_header_->target != MODEL_TARGET_CPU && model_header_->target != MODEL_TARGET_K210))
        return false;

    // Allocate buffers
    main_mem_.reset(new (std::nothrow) uint8_t[model_header_->main_mem]);
    if (!main_mem_)
        return false;

    offset += sizeof(model_header);
    inputs_ = { reinterpret_cast<const memory_range *>(offset), inputs_size() };
    offset += sizeof(memory_range) * inputs_size();
    input_shapes_ = { reinterpret_cast<const runtime_shape_t *>(offset), inputs_size() };
    offset += sizeof(runtime_shape_t) * inputs_size();
    outputs_ = { reinterpret_cast<const memory_range *>(offset), outputs_size() };
    offset += sizeof(memory_range) * outputs_size();
    constants_ = { offset, model_header_->constants };
    offset += constants_.size();
    node_headers_ = { reinterpret_cast<const node_header *>(offset), nodes_size() };
    offset += sizeof(node_header) * nodes_size();
    node_body_start_ = offset;
    return initialize();
}
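
// Total size of the serialized model: the fixed sections up to the first
// node body, plus the sum of all node body sizes. Only meaningful after a
// successful try_load_model(), which sets node_body_start_ and node_headers_.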
uint32_t interpreter_base::model_size(const uint8_t *buffer)
{
    auto size = (uint32_t)(node_body_start_ - buffer);
    for (size_t i = 0; i < nodes_size(); i++)
        size += node_headers_[i].body_size;
    return size;
}
bool interpreter_base::initialize()
{
    return true;
}
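
// Start executing the loaded model. Execution is callback driven: `callback`
// fires once after the last node completes, `on_error` fires if a kernel
// fails, and `node_profile` (when non-null) receives per-node timings.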
void interpreter_base::run(run_callback_t callback, error_callback_t on_error, node_profile_callback_t node_profile, void *userdata)
{
    run_callback_ = callback;
    on_error_ = on_error;
    node_profile_ = node_profile;
    userdata_ = userdata;
    cnt_node_ = 0;
    cnt_node_body_ = node_body_start_;
    total_duration_ = {};
    last_time_.reset();
    step();
}
interpreter_base::clock_t::time_point interpreter_base::get_now() const noexcept
{
    return clock_t::now();
}
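
// Run nodes one at a time. The loop spins while kernels complete
// synchronously (kcr_done); call_kernel also receives &interpreter_base::step
// as a continuation, so a kernel that returns another non-error code can
// re-enter step() later to resume execution. The timing block at the top of
// each pass charges the elapsed interval to the node that just finished.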
void interpreter_base::step()
{
    auto result = kcr_done;
    while (result == kcr_done)
    {
        if (!last_time_)
        {
            last_time_ = get_now();
        }
        else
        {
            auto now = get_now();
            auto duration = now - *last_time_;
            total_duration_ += duration;
            if (node_profile_)
            {
                node_profile_(last_op_, duration, userdata_);
                // Re-read the clock so time spent inside the profile
                // callback is not charged to the next node.
                now = get_now();
            }
            // Record the new baseline for timing the next node.
            last_time_ = now;
        }

        if (cnt_node_ == nodes_size())
        {
            run_callback_(userdata_);
            break;
        }
        else
        {
            auto node_id = cnt_node_++;
            auto header = node_headers_[node_id];
            xtl::span<const uint8_t> body(cnt_node_body_, header.body_size);
            cnt_node_body_ += header.body_size;
            last_op_ = header.opcode;
            result = call_kernel(header.opcode, body, static_cast<interpreter_t &>(*this), &interpreter_base::step);
            if (result == kcr_error)
            {
                if (on_error_)
                {
                    char buffer[256];
                    auto name = node_opcode_names(header.opcode);
                    if (!name.empty())
                        std::snprintf(buffer, sizeof(buffer), "error occurred while running kernel: %s", name.data());
                    else
                        std::snprintf(buffer, sizeof(buffer), "Unknown opcode: (%d)", header.opcode);
                    on_error_(buffer, userdata_);
                }
                break;
            }
        }
    }
}
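
// Map a memory_range from the model onto live storage: mem_const ranges
// resolve into the constants section of the model buffer, mem_main ranges
// into the working memory allocated in try_load_model.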
xtl::span<uint8_t> interpreter_base::memory_at(const memory_range &range) const noexcept
{
    uintptr_t base;
    switch (range.memory_type)
    {
    case mem_const:
        base = (uintptr_t)constants_.data();
        break;
    case mem_main:
        base = (uintptr_t)main_mem_.get();
        break;
    default:
        base = 0;
        assert(!"Invalid memory type");
        break;
    }

    return { reinterpret_cast<uint8_t *>(base + range.start), range.size };
}