test_utils.cc

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/micro/testing/test_utils.h"

#include "tensorflow/lite/micro/simple_memory_allocator.h"

namespace tflite {
namespace testing {

namespace {

// TODO(b/141330728): Refactor out of test_utils.cc
// The variables below (and the AllocatePersistentBuffer function) are only
// needed for the kernel tests and benchmarks, i.e. where we do not have an
// interpreter object, and the fully featured MicroAllocator.
// Currently, these need to be sufficient for all the kernel_tests. If that
// becomes problematic, we can investigate allowing the arena_size to be
// specified for each call to PopulateContext.
constexpr size_t kArenaSize = 10000;
uint8_t raw_arena_[kArenaSize];
SimpleMemoryAllocator* simple_memory_allocator_ = nullptr;
constexpr size_t kBufferAlignment = 16;

// We store the pointer to the ith scratch buffer to implement the Request/Get
// ScratchBuffer API for the tests. scratch_buffers_[i] will be the ith scratch
// buffer and will still be allocated from within raw_arena_.
constexpr int kNumScratchBuffers = 5;
uint8_t* scratch_buffers_[kNumScratchBuffers];
int scratch_buffer_count_ = 0;

// Note that the context parameter in this function is only needed to match the
// signature of TfLiteContext::AllocatePersistentBuffer and isn't needed in the
// implementation because we are assuming a single global
// simple_memory_allocator_.
void* AllocatePersistentBuffer(TfLiteContext* context, size_t bytes) {
  TFLITE_DCHECK(simple_memory_allocator_ != nullptr);
  return simple_memory_allocator_->AllocateFromTail(bytes, kBufferAlignment);
}

TfLiteStatus RequestScratchBufferInArena(TfLiteContext* context, size_t bytes,
                                         int* buffer_index) {
  TFLITE_DCHECK(simple_memory_allocator_ != nullptr);
  TFLITE_DCHECK(buffer_index != nullptr);

  if (scratch_buffer_count_ == kNumScratchBuffers) {
    TF_LITE_REPORT_ERROR(
        static_cast<ErrorReporter*>(context->impl_),
        "Exceeded the maximum number of scratch tensors allowed (%d).",
        kNumScratchBuffers);
    return kTfLiteError;
  }

  // For tests, we allocate scratch buffers from the tail and keep them around
  // for the lifetime of the model. This means that the arena size in the tests
  // will be larger than what we would need if the scratch buffers could share
  // memory.
  scratch_buffers_[scratch_buffer_count_] =
      simple_memory_allocator_->AllocateFromTail(bytes, kBufferAlignment);
  TFLITE_DCHECK(scratch_buffers_[scratch_buffer_count_] != nullptr);

  *buffer_index = scratch_buffer_count_++;
  return kTfLiteOk;
}

void* GetScratchBuffer(TfLiteContext* context, int buffer_index) {
  TFLITE_DCHECK(scratch_buffer_count_ <= kNumScratchBuffers);
  if (buffer_index >= scratch_buffer_count_) {
    return nullptr;
  }
  return scratch_buffers_[buffer_index];
}

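// Illustrative sketch (not part of the original file): once PopulateContext
// (defined below) has wired these helpers into a TfLiteContext, a kernel under
// test reaches them through the standard context callbacks, e.g.:
//
//   int buffer_idx = -1;
//   TF_LITE_ENSURE_OK(context, context->RequestScratchBufferInArena(
//                                  context, /*bytes=*/128, &buffer_idx));
//   // ... later, typically in Eval():
//   void* scratch = context->GetScratchBuffer(context, buffer_idx);
//
// The callback names follow TfLiteContext; the byte count is an arbitrary
// example value.
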
TfLiteTensor* GetTensor(const struct TfLiteContext* context, int tensor_idx) {
  // TODO(b/160894903): Return this value from temp allocated memory.
  return &context->tensors[tensor_idx];
}

}  // namespace

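// Converts a float value into an unsigned eight-bit quantized value.
// Worked example (illustrative, assuming the ScaleFromMinMax /
// ZeroPointFromMinMax helpers in test_utils.h use scale = (max - min) / 255
// and the matching zero point): with min = -1.f and max = 1.f the scale is
// 2 / 255 and the zero point is 128, so F2Q(0.f, -1.f, 1.f) returns 128 and
// F2Q(1.f, -1.f, 1.f) saturates to 255.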
uint8_t F2Q(float value, float min, float max) {
  int32_t result = ZeroPointFromMinMax<uint8_t>(min, max) +
                   (value / ScaleFromMinMax<uint8_t>(min, max)) + 0.5f;
  if (result < std::numeric_limits<uint8_t>::min()) {
    result = std::numeric_limits<uint8_t>::min();
  }
  if (result > std::numeric_limits<uint8_t>::max()) {
    result = std::numeric_limits<uint8_t>::max();
  }
  return result;
}

// Converts a float value into a signed eight-bit quantized value.
int8_t F2QS(float value, float min, float max) {
  return F2Q(value, min, max) + std::numeric_limits<int8_t>::min();
}

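// Converts a float value into a signed 32-bit quantized value, given the
// quantization scale. Illustrative example: F2Q32(0.5f, 0.25f) divides the
// value by the scale and returns 2; results outside the int32_t range are
// clamped before the final cast.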
int32_t F2Q32(float value, float scale) {
  double quantized = static_cast<double>(value / scale);
  if (quantized > std::numeric_limits<int32_t>::max()) {
    quantized = std::numeric_limits<int32_t>::max();
  } else if (quantized < std::numeric_limits<int32_t>::min()) {
    quantized = std::numeric_limits<int32_t>::min();
  }
  return static_cast<int32_t>(quantized);
}

// TODO(b/141330728): Move this method elsewhere as part of cleanup.
void PopulateContext(TfLiteTensor* tensors, int tensors_size,
                     ErrorReporter* error_reporter, TfLiteContext* context) {
  simple_memory_allocator_ =
      SimpleMemoryAllocator::Create(error_reporter, raw_arena_, kArenaSize);
  TFLITE_DCHECK(simple_memory_allocator_ != nullptr);
  scratch_buffer_count_ = 0;

  context->tensors_size = tensors_size;
  context->tensors = tensors;
  context->impl_ = static_cast<void*>(error_reporter);
  context->GetExecutionPlan = nullptr;
  context->ResizeTensor = nullptr;
  context->ReportError = ReportOpError;
  context->AddTensors = nullptr;
  context->GetNodeAndRegistration = nullptr;
  context->ReplaceNodeSubsetsWithDelegateKernels = nullptr;
  context->recommended_num_threads = 1;
  context->GetExternalContext = nullptr;
  context->SetExternalContext = nullptr;
  context->GetTensor = GetTensor;
  context->GetEvalTensor = nullptr;

  context->AllocatePersistentBuffer = AllocatePersistentBuffer;
  context->RequestScratchBufferInArena = RequestScratchBufferInArena;
  context->GetScratchBuffer = GetScratchBuffer;

  for (int i = 0; i < tensors_size; ++i) {
    if (context->tensors[i].is_variable) {
      ResetVariableTensor(&context->tensors[i]);
    }
  }
}

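// Sketch of typical kernel-test usage (illustrative only; the tensor setup is
// a placeholder, and micro_test::reporter is assumed to be the error reporter
// provided by the surrounding micro test harness, not this file):
//
//   TfLiteTensor tensors[2];  // built with the Create*Tensor helpers below
//   TfLiteContext context;
//   PopulateContext(tensors, /*tensors_size=*/2, micro_test::reporter,
//                   &context);
//   // The context can now be handed to a kernel's Prepare()/Invoke().
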
TfLiteTensor CreateQuantizedTensor(const uint8_t* data, TfLiteIntArray* dims,
                                   float min, float max, bool is_variable) {
  TfLiteTensor result;
  result.type = kTfLiteUInt8;
  result.data.uint8 = const_cast<uint8_t*>(data);
  result.dims = dims;
  result.params = {ScaleFromMinMax<uint8_t>(min, max),
                   ZeroPointFromMinMax<uint8_t>(min, max)};
  result.allocation_type = kTfLiteMemNone;
  result.bytes = ElementCount(*dims) * sizeof(uint8_t);
  result.is_variable = is_variable;
  return result;
}

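// Illustrative pairing with F2Q (values are examples; `dims` is assumed to be
// a valid TfLiteIntArray built elsewhere):
//
//   const float kMin = -1.f, kMax = 1.f;
//   const uint8_t data[] = {F2Q(-1.f, kMin, kMax), F2Q(0.f, kMin, kMax),
//                           F2Q(1.f, kMin, kMax)};
//   TfLiteTensor t =
//       CreateQuantizedTensor(data, dims, kMin, kMax, /*is_variable=*/false);
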
TfLiteTensor CreateQuantizedTensor(const int8_t* data, TfLiteIntArray* dims,
                                   float min, float max, bool is_variable) {
  TfLiteTensor result;
  result.type = kTfLiteInt8;
  result.data.int8 = const_cast<int8_t*>(data);
  result.dims = dims;
  result.params = {ScaleFromMinMax<int8_t>(min, max),
                   ZeroPointFromMinMax<int8_t>(min, max)};
  result.allocation_type = kTfLiteMemNone;
  result.bytes = ElementCount(*dims) * sizeof(int8_t);
  result.is_variable = is_variable;
  return result;
}

TfLiteTensor CreateQuantizedTensor(float* data, uint8_t* quantized_data,
                                   TfLiteIntArray* dims, bool is_variable) {
  TfLiteTensor result;
  SymmetricQuantize(data, dims, quantized_data, &result.params.scale);
  result.data.uint8 = quantized_data;
  result.type = kTfLiteUInt8;
  result.dims = dims;
  result.params.zero_point = 128;
  result.allocation_type = kTfLiteMemNone;
  result.bytes = ElementCount(*dims) * sizeof(uint8_t);
  result.is_variable = is_variable;
  return result;
}

TfLiteTensor CreateQuantizedTensor(float* data, int8_t* quantized_data,
                                   TfLiteIntArray* dims, bool is_variable) {
  TfLiteTensor result;
  SignedSymmetricQuantize(data, dims, quantized_data, &result.params.scale);
  result.data.int8 = quantized_data;
  result.type = kTfLiteInt8;
  result.dims = dims;
  result.params.zero_point = 0;
  result.allocation_type = kTfLiteMemNone;
  result.bytes = ElementCount(*dims) * sizeof(int8_t);
  result.is_variable = is_variable;
  return result;
}

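// Illustrative use of the symmetric overload above (names are examples;
// `dims` is assumed to describe three elements). Here the scale is computed
// by SignedSymmetricQuantize rather than supplied by the caller:
//
//   float weights[] = {-1.f, 0.5f, 1.f};
//   int8_t quantized_weights[3];
//   TfLiteTensor w = CreateQuantizedTensor(weights, quantized_weights, dims,
//                                          /*is_variable=*/false);
//   // w.params.scale now holds the scale chosen for `weights`.
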
TfLiteTensor CreateQuantizedTensor(float* data, int16_t* quantized_data,
                                   TfLiteIntArray* dims, bool is_variable) {
  TfLiteTensor result;
  SignedSymmetricQuantize(data, dims, quantized_data, &result.params.scale);
  result.data.i16 = quantized_data;
  result.type = kTfLiteInt16;
  result.dims = dims;
  result.params.zero_point = 0;
  result.allocation_type = kTfLiteMemNone;
  result.bytes = ElementCount(*dims) * sizeof(int16_t);
  result.is_variable = is_variable;
  return result;
}

TfLiteTensor CreateQuantized32Tensor(const int32_t* data, TfLiteIntArray* dims,
                                     float scale, bool is_variable) {
  TfLiteTensor result;
  result.type = kTfLiteInt32;
  result.data.i32 = const_cast<int32_t*>(data);
  result.dims = dims;
  // Quantized int32_t tensors always have a zero point of 0: the int32_t range
  // is large enough without an offset, and a non-zero zero point would cost
  // extra cycles during processing.
  result.params = {scale, 0};
  result.allocation_type = kTfLiteMemNone;
  result.bytes = ElementCount(*dims) * sizeof(int32_t);
  result.is_variable = is_variable;
  return result;
}

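// Illustrative use for a bias tensor (scale is an example value; bias scales
// are conventionally input_scale * filter_scale, and `bias_dims` is assumed
// to be built elsewhere):
//
//   const int32_t bias_data[] = {F2Q32(0.1f, 0.005f)};  // quantizes to 20
//   TfLiteTensor bias = CreateQuantized32Tensor(bias_data, bias_dims, 0.005f,
//                                               /*is_variable=*/false);
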
}  // namespace testing
}  // namespace tflite