pooling.cc

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
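
// Pooling kernels (AVERAGE_POOL_2D / MAX_POOL_2D) for TensorFlow Lite Micro,
// with CMSIS-NN fast paths for int8 tensors on Arm Cortex-M targets.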

#include "tensorflow/lite/kernels/internal/reference/pooling.h"

#include "cmsis/CMSIS/NN/Include/arm_nnfunctions.h"
#include "flatbuffers/base.h"  // from @flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace pooling {
namespace {

constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

struct OpData {
  TfLitePaddingValues padding;
  // Index to buffer for optimizations if applicable.
  int buffer_idx;

  int32_t activation_min;
  int32_t activation_max;
};
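
// Computes padding and, for quantized inputs, the activation range. Shared by
// both the average- and max-pooling Prepare functions below.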
TfLiteStatus CalculateOpData(TfLiteContext* context,
                             const TfLitePoolParams* params,
                             const TfLiteTensor* input, TfLiteTensor* output,
                             OpData* data) {
  // input: batch, height, width, channel
  int height = SizeOfDimension(input, 1);
  int width = SizeOfDimension(input, 2);

  int out_height, out_width;

  data->padding = ComputePaddingHeightWidth(
      params->stride_height, params->stride_width,
      /*dilation_rate_height=*/1,
      /*dilation_rate_width=*/1, height, width, params->filter_height,
      params->filter_width, params->padding, &out_height, &out_width);

  if (input->type != kTfLiteFloat32) {
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->activation_min,
        &data->activation_max));
    TFLITE_DCHECK_LE(data->activation_min, data->activation_max);
  }

  // Set buffer index to a reset value
  data->buffer_idx = -1;

  return kTfLiteOk;
}
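
// Float average pooling goes through the portable reference kernel; the
// CMSIS-NN fast path below only covers int8.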
void AverageEvalFloat(const TfLiteContext* context, const TfLiteNode* node,
                      const TfLitePoolParams* params, const OpData& data,
                      const TfLiteEvalTensor* input, TfLiteEvalTensor* output) {
  float activation_min, activation_max;
  CalculateActivationRange(params->activation, &activation_min,
                           &activation_max);

  PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data.padding.height;
  op_params.padding_values.width = data.padding.width;
  op_params.float_activation_min = activation_min;
  op_params.float_activation_max = activation_max;
  reference_ops::AveragePool(op_params, tflite::micro::GetTensorShape(input),
                             tflite::micro::GetTensorData<float>(input),
                             tflite::micro::GetTensorShape(output),
                             tflite::micro::GetTensorData<float>(output));
}
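
// Quantized average pooling: uint8 uses the portable reference kernel, while
// int8 dispatches to CMSIS-NN's arm_avgpool_s8, using the scratch buffer
// requested in AveragePrepare when one was needed.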
void AverageEvalQuantized(TfLiteContext* context, const TfLiteNode* node,
                          const TfLitePoolParams* params, const OpData& data,
                          const TfLiteEvalTensor* input,
                          TfLiteEvalTensor* output) {
  TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8);

  PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data.padding.height;
  op_params.padding_values.width = data.padding.width;
  op_params.quantized_activation_min = data.activation_min;
  op_params.quantized_activation_max = data.activation_max;

  if (input->type == kTfLiteUInt8) {
    reference_ops::AveragePool(op_params, tflite::micro::GetTensorShape(input),
                               tflite::micro::GetTensorData<uint8_t>(input),
                               tflite::micro::GetTensorShape(output),
                               tflite::micro::GetTensorData<uint8_t>(output));
  } else {
    RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
    TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);

    RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
    TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

    const int depth = MatchingDim(input_shape, 3, output_shape, 3);

    cmsis_nn_dims input_dims;
    input_dims.n = 1;
    input_dims.h = input_shape.Dims(1);
    input_dims.w = input_shape.Dims(2);
    input_dims.c = depth;

    cmsis_nn_dims output_dims;
    output_dims.n = 1;
    output_dims.h = output_shape.Dims(1);
    output_dims.w = output_shape.Dims(2);
    output_dims.c = depth;

    cmsis_nn_pool_params pool_params;
    pool_params.stride.h = params->stride_height;
    pool_params.stride.w = params->stride_width;
    pool_params.padding.h = data.padding.height;
    pool_params.padding.w = data.padding.width;
    pool_params.activation.min = data.activation_min;
    pool_params.activation.max = data.activation_max;

    cmsis_nn_dims filter_dims;
    filter_dims.n = 1;
    filter_dims.h = params->filter_height;
    filter_dims.w = params->filter_width;
    filter_dims.c = 1;

    cmsis_nn_context ctx;
    ctx.buf = nullptr;
    ctx.size = 0;
    if (data.buffer_idx > -1) {
      ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
    }

    TFLITE_DCHECK_EQ(
        arm_avgpool_s8(&ctx, &pool_params, &input_dims,
                       tflite::micro::GetTensorData<int8_t>(input),
                       &filter_dims, &output_dims,
                       tflite::micro::GetTensorData<int8_t>(output)),
        ARM_MATH_SUCCESS);
  }
}
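
// Float max pooling via the portable reference kernel.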
void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
                  TfLitePoolParams* params, const OpData& data,
                  const TfLiteEvalTensor* input, TfLiteEvalTensor* output) {
  float activation_min, activation_max;
  CalculateActivationRange(params->activation, &activation_min,
                           &activation_max);

  tflite::PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data.padding.height;
  op_params.padding_values.width = data.padding.width;
  op_params.float_activation_min = activation_min;
  op_params.float_activation_max = activation_max;
  reference_ops::MaxPool(op_params, tflite::micro::GetTensorShape(input),
                         tflite::micro::GetTensorData<float>(input),
                         tflite::micro::GetTensorShape(output),
                         tflite::micro::GetTensorData<float>(output));
}
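
// uint8 max pooling via the portable reference kernel; only the int8 path
// (MaxEvalInt8 below) is accelerated with CMSIS-NN.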
void MaxEvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node,
                           TfLitePoolParams* params, const OpData& data,
                           const TfLiteEvalTensor* input,
                           TfLiteEvalTensor* output) {
  tflite::PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data.padding.height;
  op_params.padding_values.width = data.padding.width;
  op_params.quantized_activation_min = data.activation_min;
  op_params.quantized_activation_max = data.activation_max;
  reference_ops::MaxPool(op_params, tflite::micro::GetTensorShape(input),
                         tflite::micro::GetTensorData<uint8_t>(input),
                         tflite::micro::GetTensorShape(output),
                         tflite::micro::GetTensorData<uint8_t>(output));
}
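
// int8 max pooling via CMSIS-NN's arm_max_pool_s8.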
TfLiteStatus MaxEvalInt8(TfLiteContext* context, const TfLiteNode* node,
                         const TfLitePoolParams* params, const OpData& data,
                         const TfLiteEvalTensor* input,
                         TfLiteEvalTensor* output) {
  RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);

  const int depth = MatchingDim(input_shape, 3, output_shape, 3);

  cmsis_nn_dims input_dims;
  input_dims.n = 1;
  input_dims.h = input_shape.Dims(1);
  input_dims.w = input_shape.Dims(2);
  input_dims.c = depth;

  cmsis_nn_dims output_dims;
  output_dims.n = 1;
  output_dims.h = output_shape.Dims(1);
  output_dims.w = output_shape.Dims(2);
  output_dims.c = depth;

  cmsis_nn_pool_params pool_params;
  pool_params.stride.h = params->stride_height;
  pool_params.stride.w = params->stride_width;
  pool_params.padding.h = data.padding.height;
  pool_params.padding.w = data.padding.width;
  pool_params.activation.min = data.activation_min;
  pool_params.activation.max = data.activation_max;

  cmsis_nn_dims filter_dims;
  filter_dims.n = 1;
  filter_dims.h = params->filter_height;
  filter_dims.w = params->filter_width;
  filter_dims.c = 1;

  cmsis_nn_context ctx;
  ctx.buf = nullptr;
  ctx.size = 0;
  if (data.buffer_idx > -1) {
    ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
  }

  TFLITE_DCHECK_EQ(
      arm_max_pool_s8(&ctx, &pool_params, &input_dims,
                      tflite::micro::GetTensorData<int8_t>(input), &filter_dims,
                      &output_dims,
                      tflite::micro::GetTensorData<int8_t>(output)),
      ARM_MATH_SUCCESS);

  return kTfLiteOk;
}

}  // namespace
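
// Allocates per-node OpData in the persistent arena; the allocation persists
// across invocations.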
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}
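
// Max pooling needs no scratch memory, so Prepare only fills in the shared
// OpData.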
TfLiteStatus MaxPrepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpData* data = static_cast<OpData*>(node->user_data);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, data));

  return kTfLiteOk;
}
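
// In addition to the shared OpData, the int8 average-pool path may need a
// CMSIS-NN scratch buffer; its size comes from arm_avgpool_s8_get_buffer_size
// and the buffer is reserved in the arena here.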
TfLiteStatus AveragePrepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpData* data = static_cast<OpData*>(node->user_data);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, data));

  if (input->type == kTfLiteInt8) {
    RuntimeShape input_shape = GetTensorShape(input);
    TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);

    RuntimeShape output_shape = GetTensorShape(output);
    TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

    const int depth = MatchingDim(input_shape, 3, output_shape, 3);
    const int output_width = output_shape.Dims(2);

    const int32_t buffer_size =
        arm_avgpool_s8_get_buffer_size(output_width, depth);

    if (buffer_size > 0) {
      TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
          context, buffer_size, &data->buffer_idx));
    } else {
      data->buffer_idx = -1;
    }
  }
  return kTfLiteOk;
}
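
// Entry point for AVERAGE_POOL_2D: dispatches on the input type.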
TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  // Inputs and outputs share the same type, guaranteed by the converter.
  switch (input->type) {
    case kTfLiteFloat32:
      AverageEvalFloat(context, node, params, data, input, output);
      break;
    case kTfLiteUInt8:
    case kTfLiteInt8:
      AverageEvalQuantized(context, node, params, data, input, output);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}
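
// Entry point for MAX_POOL_2D: dispatches on the input type.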
TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  switch (input->type) {
    case kTfLiteFloat32:
      MaxEvalFloat(context, node, params, data, input, output);
      break;
    case kTfLiteUInt8:
      MaxEvalQuantizedUInt8(context, node, params, data, input, output);
      break;
    case kTfLiteInt8:
      MaxEvalInt8(context, node, params, data, input, output);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace pooling
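
// Registration entry points consumed by the op resolver.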
TfLiteRegistration Register_AVERAGE_POOL_2D() {
  return {/*init=*/pooling::Init,
          /*free=*/nullptr,
          /*prepare=*/pooling::AveragePrepare,
          /*invoke=*/pooling::AverageEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_MAX_POOL_2D() {
  return {/*init=*/pooling::Init,
          /*free=*/nullptr,
          /*prepare=*/pooling::MaxPrepare,
          /*invoke=*/pooling::MaxEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}
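
// Usage sketch (illustrative only, not part of this file; the resolver API
// varies across TFLite Micro versions, so treat the calls below as an
// assumption rather than a fixed interface):
//
//   static TfLiteRegistration avg_pool = Register_AVERAGE_POOL_2D();
//   static TfLiteRegistration max_pool = Register_MAX_POOL_2D();
//   resolver.AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, &avg_pool);
//   resolver.AddBuiltin(BuiltinOperator_MAX_POOL_2D, &max_pool);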

}  // namespace micro
}  // namespace ops
}  // namespace tflite