/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_
#define TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_

#include <cstddef>
#include <cstdint>

#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/type_to_tflitetype.h"

namespace tflite {
namespace internal {

// A helper class to encapsulate the implementation of APIs in Context.
// context->impl_ points to an instance of this class.
// Check tensorflow/lite/c/common.h for detailed descriptions.
class ContextHelper {
 public:
  explicit ContextHelper(ErrorReporter* error_reporter,
                         MicroAllocator* allocator)
      : allocator_(allocator), error_reporter_(error_reporter) {}

  static TfLiteStatus AllocatePersistentBuffer(TfLiteContext* ctx, size_t bytes,
                                               void** ptr);
  static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* ctx,
                                                  size_t bytes,
                                                  int* buffer_idx);
  static void* GetScratchBuffer(TfLiteContext* ctx, int buffer_idx);
  static void ReportOpError(struct TfLiteContext* context, const char* format,
                            ...);

  void SetNodeIndex(int idx) { current_node_idx_ = idx; }

 private:
  MicroAllocator* allocator_;
  ErrorReporter* error_reporter_;
  int current_node_idx_ = -1;
};

}  // namespace internal
class MicroInterpreter {
 public:
  // The lifetime of the model, op resolver, tensor arena, and error reporter
  // must be at least as long as that of the interpreter object, since the
  // interpreter may need to access them at any time. This means that you
  // should usually create them with the same scope as each other, for
  // example by allocating them all on the stack as local variables in a
  // top-level function. The interpreter doesn't do any deallocation of the
  // pointed-to objects; ownership remains with the caller.
  MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver,
                   uint8_t* tensor_arena, size_t tensor_arena_size,
                   ErrorReporter* error_reporter);
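  //
  // Example use of the constructor above (a sketch only, not part of this
  // header: `g_model_data` and `kArenaSize` are hypothetical names, and
  // AllOpsResolver / MicroErrorReporter need their own includes):
  //
  //   const tflite::Model* model = tflite::GetModel(g_model_data);
  //   tflite::MicroErrorReporter error_reporter;
  //   tflite::AllOpsResolver resolver;
  //   constexpr size_t kArenaSize = 10 * 1024;
  //   alignas(16) static uint8_t tensor_arena[kArenaSize];
  //   tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
  //                                        kArenaSize, &error_reporter);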
  // Create an interpreter instance using an existing MicroAllocator instance.
  // This constructor should be used when one allocator needs to handle
  // allocations for more than one interpreter, or for recording allocations
  // inside the interpreter. The lifetime of the allocator must be as long as
  // that of the interpreter object.
  MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver,
                   MicroAllocator* allocator, ErrorReporter* error_reporter);
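  //
  // Sketch of sharing one allocator across two interpreters (this assumes a
  // MicroAllocator::Create(arena, size, error_reporter) factory; check
  // micro_allocator.h for the exact signature):
  //
  //   tflite::MicroAllocator* allocator = tflite::MicroAllocator::Create(
  //       tensor_arena, kArenaSize, &error_reporter);
  //   tflite::MicroInterpreter interp_a(model_a, resolver, allocator,
  //                                     &error_reporter);
  //   tflite::MicroInterpreter interp_b(model_b, resolver, allocator,
  //                                     &error_reporter);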
  ~MicroInterpreter();

  // Runs through the model and allocates all necessary input, output and
  // intermediate tensors.
  TfLiteStatus AllocateTensors();

  // In order to support partial graph runs for strided models, this can
  // return values other than kTfLiteOk and kTfLiteError.
  // TODO(b/149795762): Add this to the TfLiteStatus enum.
  TfLiteStatus Invoke();
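  //
  // Typical invocation sequence (a sketch with error handling elided; it
  // assumes a model with one float input and one float output):
  //
  //   interpreter.AllocateTensors();  // must be called before Invoke()
  //   float* in = interpreter.typed_input_tensor<float>(0);
  //   in[0] = 1.0f;                   // populate input data
  //   interpreter.Invoke();           // run the graph
  //   float* out = interpreter.typed_output_tensor<float>(0);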
  size_t tensors_size() const { return context_.tensors_size; }
  TfLiteTensor* tensor(size_t tensor_index);
  template <class T>
  T* typed_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return GetTensorData<T>(tensor_ptr);
      }
    }
    return nullptr;
  }

  TfLiteTensor* input(size_t index);
  size_t inputs_size() const { return subgraph_->inputs()->Length(); }
  const flatbuffers::Vector<int32_t>& inputs() const {
    return *subgraph_->inputs();
  }
  TfLiteTensor* input_tensor(size_t index) { return input(index); }
  template <class T>
  T* typed_input_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = input_tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return GetTensorData<T>(tensor_ptr);
      }
    }
    return nullptr;
  }

  TfLiteTensor* output(size_t index);
  size_t outputs_size() const { return subgraph_->outputs()->Length(); }
  const flatbuffers::Vector<int32_t>& outputs() const {
    return *subgraph_->outputs();
  }
  TfLiteTensor* output_tensor(size_t index) { return output(index); }
  template <class T>
  T* typed_output_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = output_tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return GetTensorData<T>(tensor_ptr);
      }
    }
    return nullptr;
  }
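  //
  // Note that the typed accessors above return nullptr on an index or type
  // mismatch, so callers should check the result (sketch; assumes input 0 is
  // quantized to int8):
  //
  //   int8_t* data = interpreter.typed_input_tensor<int8_t>(0);
  //   if (data == nullptr) {
  //     // Wrong type requested; inspect the untyped tensor instead.
  //     TfLiteTensor* t = interpreter.input(0);
  //   }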
  // Reset all variable tensors to the default value.
  TfLiteStatus ResetVariableTensors();

  TfLiteStatus initialization_status() const { return initialization_status_; }

  size_t operators_size() const { return subgraph_->operators()->size(); }

  // For debugging only.
  const NodeAndRegistration node_and_registration(int node_index) const {
    return node_and_registrations_[node_index];
  }

  // For debugging only.
  // Returns the actual used arena in bytes. This method gives the optimal
  // arena size. It's only available after `AllocateTensors` has been called.
  // Note that normally `tensor_arena` requires 16-byte alignment to fully
  // utilize the space. If that's not the case, the optimal arena size would
  // be arena_used_bytes() + 16.
  size_t arena_used_bytes() const { return allocator_.used_bytes(); }
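  //
  // Example of right-sizing the arena during development (a sketch; start
  // with a deliberately generous kArenaSize, then shrink it to the reported
  // value):
  //
  //   interpreter.AllocateTensors();
  //   size_t optimal = interpreter.arena_used_bytes() + 16;  // alignment pad
  //   TF_LITE_REPORT_ERROR(&error_reporter, "Arena bytes needed: %d",
  //                        static_cast<int>(optimal));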
 private:
  // TODO(b/158263161): Consider switching to Create() function to enable
  // better error reporting during initialization.
  void Init();

  void CorrectTensorEndianness(TfLiteTensor* tensorCorr);

  template <class T>
  void CorrectTensorDataEndianness(T* data, int32_t size);

  NodeAndRegistration* node_and_registrations_ = nullptr;

  const Model* model_;
  const MicroOpResolver& op_resolver_;
  ErrorReporter* error_reporter_;
  TfLiteContext context_ = {};
  MicroAllocator& allocator_;
  bool tensors_allocated_;

  TfLiteStatus initialization_status_;

  const SubGraph* subgraph_;
  internal::ContextHelper context_helper_;
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_