micro_interpreter.cc

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/micro_interpreter.h"

#include <cstdarg>
#include <cstddef>
#include <cstdint>

#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/tensor_utils.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace {

const char* OpNameFromRegistration(const TfLiteRegistration* registration) {
  if (registration->builtin_code == BuiltinOperator_CUSTOM) {
    return registration->custom_name;
  } else {
    return EnumNameBuiltinOperator(
        BuiltinOperator(registration->builtin_code));
  }
}

}  // namespace

namespace internal {
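
// The static functions below are installed as C-style callbacks on the
// TfLiteContext handed to kernels; each one recovers the ContextHelper
// instance from context->impl_ and forwards the call to the interpreter's
// MicroAllocator or ErrorReporter.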
TfLiteStatus ContextHelper::AllocatePersistentBuffer(TfLiteContext* ctx,
                                                     size_t bytes, void** ptr) {
  return reinterpret_cast<ContextHelper*>(ctx->impl_)
      ->allocator_->AllocatePersistentBuffer(bytes, ptr);
}

TfLiteStatus ContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx,
                                                        size_t bytes,
                                                        int* buffer_idx) {
  ContextHelper* helper = reinterpret_cast<ContextHelper*>(ctx->impl_);
  return helper->allocator_->RequestScratchBufferInArena(
      helper->current_node_idx_, bytes, buffer_idx);
}

void* ContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) {
  return reinterpret_cast<ContextHelper*>(ctx->impl_)
      ->allocator_->GetScratchBuffer(buffer_idx);
}

void ContextHelper::ReportOpError(struct TfLiteContext* context,
                                  const char* format, ...) {
  ContextHelper* helper = static_cast<ContextHelper*>(context->impl_);
  va_list args;
  va_start(args, format);
  TF_LITE_REPORT_ERROR(helper->error_reporter_, format, args);
  va_end(args);
}

}  // namespace internal

MicroInterpreter::MicroInterpreter(const Model* model,
                                   const MicroOpResolver& op_resolver,
                                   uint8_t* tensor_arena,
                                   size_t tensor_arena_size,
                                   ErrorReporter* error_reporter)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(error_reporter),
      allocator_(*MicroAllocator::Create(&context_, model, tensor_arena,
                                         tensor_arena_size, error_reporter)),
      context_helper_(error_reporter_, &allocator_) {
  Init();
}
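
// A minimal usage sketch for this constructor (not part of the original file;
// the resolver type, the extra headers, and g_model_data are assumptions and
// may differ between TFLM versions):
//
//   #include "tensorflow/lite/micro/all_ops_resolver.h"
//   #include "tensorflow/lite/micro/micro_error_reporter.h"
//
//   static uint8_t tensor_arena[16 * 1024];
//   tflite::MicroErrorReporter micro_error_reporter;
//   tflite::AllOpsResolver resolver;
//   const tflite::Model* model = tflite::GetModel(g_model_data);
//   tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
//                                        sizeof(tensor_arena),
//                                        &micro_error_reporter);
//   if (interpreter.AllocateTensors() == kTfLiteOk) {
//     interpreter.Invoke();
//   }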

MicroInterpreter::MicroInterpreter(const Model* model,
                                   const MicroOpResolver* op_resolver,
                                   MicroAllocator* allocator,
                                   ErrorReporter* error_reporter)
    : model_(model),
      op_resolver_(*op_resolver),
      error_reporter_(error_reporter),
      allocator_(*allocator),
      context_helper_(error_reporter_, &allocator_) {
  Init();
}

MicroInterpreter::~MicroInterpreter() {
  if (node_and_registrations_ != nullptr) {
    for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
      TfLiteNode* node = &(node_and_registrations_[i].node);
      const TfLiteRegistration* registration =
          node_and_registrations_[i].registration;
      // registration is allocated outside the interpreter, so double check to
      // make sure it's not nullptr.
      if (registration != nullptr && registration->free != nullptr) {
        registration->free(&context_, node->user_data);
      }
    }
  }
}

void MicroInterpreter::Init() {
  const flatbuffers::Vector<flatbuffers::Offset<SubGraph>>* subgraphs =
      model_->subgraphs();
  if (subgraphs->size() != 1) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Only 1 subgraph is currently supported.\n");
    initialization_status_ = kTfLiteError;
    return;
  }
  subgraph_ = (*subgraphs)[0];

  context_.impl_ = static_cast<void*>(&context_helper_);
  context_.ReportError = context_helper_.ReportOpError;
  context_.recommended_num_threads = 1;

  // If the system is big endian then convert weights from the flatbuffer from
  // little to big endian on startup so that it does not need to be done during
  // inference.
  // NOTE: This requires that the flatbuffer is held in memory which can be
  // modified by this process.
  if (!FLATBUFFERS_LITTLEENDIAN) {
    for (size_t t = 0; t < tensors_size(); ++t) {
      TfLiteTensor* thisTensor = &context_.tensors[t];
      if (thisTensor->allocation_type == kTfLiteMmapRo)
        CorrectTensorEndianness(thisTensor);
    }
  }

  initialization_status_ = kTfLiteOk;
}

void MicroInterpreter::CorrectTensorEndianness(TfLiteTensor* tensorCorr) {
  int32_t tensorSize = 1;
  for (int d = 0; d < tensorCorr->dims->size; ++d)
    tensorSize *= reinterpret_cast<const int32_t*>(tensorCorr->dims->data)[d];

  switch (tensorCorr->type) {
    case TfLiteType::kTfLiteFloat32:
      CorrectTensorDataEndianness(tensorCorr->data.f, tensorSize);
      break;
    case TfLiteType::kTfLiteFloat16:
      CorrectTensorDataEndianness(tensorCorr->data.f16, tensorSize);
      break;
    case TfLiteType::kTfLiteInt64:
      CorrectTensorDataEndianness(tensorCorr->data.i64, tensorSize);
      break;
    case TfLiteType::kTfLiteInt32:
      CorrectTensorDataEndianness(tensorCorr->data.i32, tensorSize);
      break;
    case TfLiteType::kTfLiteInt16:
      CorrectTensorDataEndianness(tensorCorr->data.i16, tensorSize);
      break;
    case TfLiteType::kTfLiteComplex64:
      CorrectTensorDataEndianness(tensorCorr->data.c64, tensorSize);
      break;
    default:
      // Do nothing for other data types.
      break;
  }
}

template <class T>
void MicroInterpreter::CorrectTensorDataEndianness(T* data, int32_t size) {
  for (int32_t i = 0; i < size; ++i) {
    data[i] = flatbuffers::EndianScalar(data[i]);
  }
}
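
// AllocateTensors walks the graph twice. During the first pass each kernel's
// init() runs with only AllocatePersistentBuffer exposed on the context;
// during the second pass each kernel's prepare() runs with both
// AllocatePersistentBuffer and RequestScratchBufferInArena exposed. Once
// preparation finishes, the allocation callbacks are cleared and only
// GetScratchBuffer remains available for invoke().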
TfLiteStatus MicroInterpreter::AllocateTensors() {
  TF_LITE_ENSURE_OK(&context_, allocator_.PrepareFromFlatbuffer(
                                   op_resolver_, &node_and_registrations_));

  // Only allow AllocatePersistentBuffer in the Init stage.
  context_.AllocatePersistentBuffer = context_helper_.AllocatePersistentBuffer;
  context_.RequestScratchBufferInArena = nullptr;
  context_.GetScratchBuffer = nullptr;

  for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
    context_helper_.SetNodeIndex(i);
    auto* node = &(node_and_registrations_[i].node);
    auto* registration = node_and_registrations_[i].registration;
    size_t init_data_size;
    const char* init_data;
    if (registration->builtin_code == BuiltinOperator_CUSTOM) {
      init_data = reinterpret_cast<const char*>(node->custom_initial_data);
      init_data_size = node->custom_initial_data_size;
    } else {
      init_data = reinterpret_cast<const char*>(node->builtin_data);
      init_data_size = 0;
    }
    if (registration->init) {
      node->user_data =
          registration->init(&context_, init_data, init_data_size);
    }
  }
  context_helper_.SetNodeIndex(-1);

  // Both AllocatePersistentBuffer and RequestScratchBufferInArena are
  // available in the Prepare stage.
  context_.RequestScratchBufferInArena =
      context_helper_.RequestScratchBufferInArena;

  for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
    // Set the node index to annotate the lifetime of scratch buffers.
    context_helper_.SetNodeIndex(i);
    auto* node = &(node_and_registrations_[i].node);
    auto* registration = node_and_registrations_[i].registration;
    if (registration->prepare) {
      TfLiteStatus prepare_status = registration->prepare(&context_, node);
      if (prepare_status != kTfLiteOk) {
        TF_LITE_REPORT_ERROR(
            error_reporter_,
            "Node %s (number %d) failed to prepare with status %d",
            OpNameFromRegistration(registration), i, prepare_status);
        return kTfLiteError;
      }
    }
  }
  context_helper_.SetNodeIndex(-1);

  // Prepare is done, we're ready for Invoke. Memory allocation is no longer
  // allowed. Kernels can only fetch scratch buffers via GetScratchBuffer.
  context_.AllocatePersistentBuffer = nullptr;
  context_.RequestScratchBufferInArena = nullptr;
  context_.GetScratchBuffer = context_helper_.GetScratchBuffer;

  TF_LITE_ENSURE_OK(&context_, allocator_.FinishTensorAllocation());
  tensors_allocated_ = true;
  return kTfLiteOk;
}

TfLiteStatus MicroInterpreter::Invoke() {
  if (initialization_status_ != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Invoke() called after initialization failed\n");
    return kTfLiteError;
  }

  // Ensure tensors are allocated before the interpreter is invoked to avoid
  // difficult-to-debug segfaults.
  if (!tensors_allocated_) {
    TF_LITE_ENSURE_OK(&context_, AllocateTensors());
  }

  for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
    auto* node = &(node_and_registrations_[i].node);
    auto* registration = node_and_registrations_[i].registration;
    if (registration->invoke) {
      TfLiteStatus invoke_status = registration->invoke(&context_, node);
      if (invoke_status == kTfLiteError) {
        TF_LITE_REPORT_ERROR(
            error_reporter_,
            "Node %s (number %d) failed to invoke with status %d",
            OpNameFromRegistration(registration), i, invoke_status);
        return kTfLiteError;
      } else if (invoke_status != kTfLiteOk) {
        return invoke_status;
      }
    }
  }
  return kTfLiteOk;
}
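
// A minimal sketch of driving the accessors below (assumes the model's first
// input and output are float tensors):
//
//   TfLiteTensor* in = interpreter.input(0);
//   in->data.f[0] = 1.0f;
//   if (interpreter.Invoke() == kTfLiteOk) {
//     float y = interpreter.output(0)->data.f[0];
//   }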
TfLiteTensor* MicroInterpreter::input(size_t index) {
  const size_t length = inputs_size();
  if ((index < 0) || (index >= length)) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Input index %d out of range (length is %d)", index,
                         length);
    return nullptr;
  }
  return &(context_.tensors[inputs().Get(index)]);
}

TfLiteTensor* MicroInterpreter::output(size_t index) {
  const size_t length = outputs_size();
  if ((index < 0) || (index >= length)) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Output index %d out of range (length is %d)", index,
                         length);
    return nullptr;
  }
  return &(context_.tensors[outputs().Get(index)]);
}

TfLiteTensor* MicroInterpreter::tensor(size_t index) {
  const size_t length = tensors_size();
  if ((index < 0) || (index >= length)) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Tensor index %d out of range (length is %d)", index,
                         length);
    return nullptr;
  }
  return &context_.tensors[index];
}
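
// Resets every tensor flagged is_variable via tflite::ResetVariableTensor.
// Stateful models (for example ones with recurrent kernels) can call this
// between independent inference sequences.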
TfLiteStatus MicroInterpreter::ResetVariableTensors() {
  const size_t length = tensors_size();
  for (size_t i = 0; i < length; ++i) {
    TfLiteTensor* cur_tensor = tensor(i);
    if (cur_tensor->is_variable) {
      TfLiteStatus status = tflite::ResetVariableTensor(cur_tensor);
      if (status != kTfLiteOk) {
        TF_LITE_REPORT_ERROR(error_reporter_,
                             "Failed to reset variable tensor at index: %d", i);
        return status;
      }
    }
  }
  return kTfLiteOk;
}

}  // namespace tflite