pooling.cc 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "tensorflow/lite/kernels/internal/reference/pooling.h"
  13. #include "tensorflow/lite/c/builtin_op_data.h"
  14. #include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h"
  15. #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
  16. #include "tensorflow/lite/kernels/kernel_util.h"
  17. #include "tensorflow/lite/kernels/padding.h"
  18. #include "tensorflow/lite/micro/kernels/kernel_util.h"
  19. namespace tflite {
  20. namespace ops {
  21. namespace micro {
  22. namespace pooling {
  23. namespace {
  // Index passed to the input-tensor lookups (GetInput / GetEvalInput) in this
  // file.
  constexpr int kInputTensor{0};
  // Index passed to the output-tensor lookups (GetOutput / GetEvalOutput) in
  // this file.
  constexpr int kOutputTensor{0};
  26. struct OpData {
  27. TfLitePaddingValues padding;
  28. int32_t activation_min;
  29. int32_t activation_max;
  30. float activation_min_f32;
  31. float activation_max_f32;
  32. };
  33. TfLiteStatus CalculateOpData(const TfLiteContext* context,
  34. const TfLitePoolParams* params,
  35. const TfLiteTensor* input,
  36. const TfLiteTensor* output, OpData* data) {
  37. // input: batch, height, width, channel
  38. int height = SizeOfDimension(input, 1);
  39. int width = SizeOfDimension(input, 2);
  40. int out_height, out_width;
  41. data->padding = ComputePaddingHeightWidth(
  42. params->stride_height, params->stride_width,
  43. /*dilation_rate_height=*/1,
  44. /*dilation_rate_width=*/1, height, width, params->filter_height,
  45. params->filter_width, params->padding, &out_height, &out_width);
  46. return kTfLiteOk;
  47. }
  48. void AverageEvalFloat(const TfLiteContext* context, const TfLiteNode* node,
  49. const TfLitePoolParams* params, const OpData* data,
  50. const TfLiteEvalTensor* input, TfLiteEvalTensor* output) {
  51. PoolParams op_params;
  52. op_params.stride_height = params->stride_height;
  53. op_params.stride_width = params->stride_width;
  54. op_params.filter_height = params->filter_height;
  55. op_params.filter_width = params->filter_width;
  56. op_params.padding_values.height = data->padding.height;
  57. op_params.padding_values.width = data->padding.width;
  58. op_params.float_activation_min = data->activation_min_f32;
  59. op_params.float_activation_max = data->activation_max_f32;
  60. reference_ops::AveragePool(op_params, tflite::micro::GetTensorShape(input),
  61. tflite::micro::GetTensorData<float>(input),
  62. tflite::micro::GetTensorShape(output),
  63. tflite::micro::GetTensorData<float>(output));
  64. }
  65. void AverageEvalQuantized(TfLiteContext* context, const TfLiteNode* node,
  66. const TfLitePoolParams* params, const OpData* data,
  67. const TfLiteEvalTensor* input,
  68. TfLiteEvalTensor* output) {
  69. TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8);
  70. PoolParams op_params;
  71. op_params.stride_height = params->stride_height;
  72. op_params.stride_width = params->stride_width;
  73. op_params.filter_height = params->filter_height;
  74. op_params.filter_width = params->filter_width;
  75. op_params.padding_values.height = data->padding.height;
  76. op_params.padding_values.width = data->padding.width;
  77. op_params.quantized_activation_min = data->activation_min;
  78. op_params.quantized_activation_max = data->activation_max;
  79. if (input->type == kTfLiteUInt8) {
  80. reference_ops::AveragePool(op_params, tflite::micro::GetTensorShape(input),
  81. tflite::micro::GetTensorData<uint8_t>(input),
  82. tflite::micro::GetTensorShape(output),
  83. tflite::micro::GetTensorData<uint8_t>(output));
  84. } else {
  85. reference_integer_ops::AveragePool(
  86. op_params, tflite::micro::GetTensorShape(input),
  87. tflite::micro::GetTensorData<int8_t>(input),
  88. tflite::micro::GetTensorShape(output),
  89. tflite::micro::GetTensorData<int8_t>(output));
  90. }
  91. }
  92. void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
  93. TfLitePoolParams* params, const OpData* data,
  94. const TfLiteEvalTensor* input, TfLiteEvalTensor* output) {
  95. tflite::PoolParams op_params;
  96. op_params.stride_height = params->stride_height;
  97. op_params.stride_width = params->stride_width;
  98. op_params.filter_height = params->filter_height;
  99. op_params.filter_width = params->filter_width;
  100. op_params.padding_values.height = data->padding.height;
  101. op_params.padding_values.width = data->padding.width;
  102. op_params.float_activation_min = data->activation_min_f32;
  103. op_params.float_activation_max = data->activation_max_f32;
  104. reference_ops::MaxPool(op_params, tflite::micro::GetTensorShape(input),
  105. tflite::micro::GetTensorData<float>(input),
  106. tflite::micro::GetTensorShape(output),
  107. tflite::micro::GetTensorData<float>(output));
  108. }
  109. void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
  110. TfLitePoolParams* params, const OpData* data,
  111. const TfLiteEvalTensor* input, TfLiteEvalTensor* output) {
  112. tflite::PoolParams op_params;
  113. op_params.stride_height = params->stride_height;
  114. op_params.stride_width = params->stride_width;
  115. op_params.filter_height = params->filter_height;
  116. op_params.filter_width = params->filter_width;
  117. op_params.padding_values.height = data->padding.height;
  118. op_params.padding_values.width = data->padding.width;
  119. op_params.quantized_activation_min = data->activation_min;
  120. op_params.quantized_activation_max = data->activation_max;
  121. if (input->type == kTfLiteUInt8) {
  122. reference_ops::MaxPool(op_params, tflite::micro::GetTensorShape(input),
  123. tflite::micro::GetTensorData<uint8_t>(input),
  124. tflite::micro::GetTensorShape(output),
  125. tflite::micro::GetTensorData<uint8_t>(output));
  126. } else {
  127. reference_integer_ops::MaxPool(
  128. op_params, tflite::micro::GetTensorShape(input),
  129. tflite::micro::GetTensorData<int8_t>(input),
  130. tflite::micro::GetTensorShape(output),
  131. tflite::micro::GetTensorData<int8_t>(output));
  132. }
  133. }
  134. } // namespace
  135. TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
  136. TFLITE_DCHECK(node->builtin_data != nullptr);
  137. auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
  138. TFLITE_DCHECK(node->user_data != nullptr);
  139. const OpData* data = static_cast<const OpData*>(node->user_data);
  140. const TfLiteEvalTensor* input =
  141. tflite::micro::GetEvalInput(context, node, kInputTensor);
  142. TfLiteEvalTensor* output =
  143. tflite::micro::GetEvalOutput(context, node, kOutputTensor);
  144. // Inputs and outputs share the same type, guaranteed by the converter.
  145. switch (input->type) {
  146. case kTfLiteFloat32:
  147. AverageEvalFloat(context, node, params, data, input, output);
  148. break;
  149. case kTfLiteUInt8:
  150. case kTfLiteInt8:
  151. AverageEvalQuantized(context, node, params, data, input, output);
  152. break;
  153. default:
  154. TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported",
  155. TfLiteTypeGetName(input->type));
  156. return kTfLiteError;
  157. }
  158. return kTfLiteOk;
  159. }
  160. TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
  161. TFLITE_DCHECK(node->builtin_data != nullptr);
  162. auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
  163. TFLITE_DCHECK(node->user_data != nullptr);
  164. const OpData* data = static_cast<const OpData*>(node->user_data);
  165. const TfLiteEvalTensor* input =
  166. tflite::micro::GetEvalInput(context, node, kInputTensor);
  167. TfLiteEvalTensor* output =
  168. tflite::micro::GetEvalOutput(context, node, kOutputTensor);
  169. switch (input->type) {
  170. case kTfLiteFloat32:
  171. MaxEvalFloat(context, node, params, data, input, output);
  172. break;
  173. case kTfLiteUInt8:
  174. case kTfLiteInt8:
  175. MaxEvalQuantized(context, node, params, data, input, output);
  176. break;
  177. default:
  178. TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
  179. TfLiteTypeGetName(input->type));
  180. return kTfLiteError;
  181. }
  182. return kTfLiteOk;
  183. }
  184. void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  185. TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  186. return context->AllocatePersistentBuffer(context, sizeof(OpData));
  187. }
  188. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  189. TFLITE_DCHECK(node->builtin_data != nullptr);
  190. auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
  191. TFLITE_DCHECK(node->user_data != nullptr);
  192. OpData* data = static_cast<OpData*>(node->user_data);
  193. const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  194. TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  195. TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, data));
  196. if (input->type == kTfLiteFloat32) {
  197. CalculateActivationRange(params->activation, &data->activation_min_f32,
  198. &data->activation_max_f32);
  199. } else if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
  200. CalculateActivationRangeQuantized(context, params->activation, output,
  201. &data->activation_min,
  202. &data->activation_max);
  203. }
  204. return kTfLiteOk;
  205. }
  206. } // namespace pooling
  207. TfLiteRegistration Register_AVERAGE_POOL_2D() {
  208. return {/*init=*/pooling::Init,
  209. /*free=*/nullptr,
  210. /*prepare=*/pooling::Prepare,
  211. /*invoke=*/pooling::AverageEval,
  212. /*profiling_string=*/nullptr,
  213. /*builtin_code=*/0,
  214. /*custom_name=*/nullptr,
  215. /*version=*/0};
  216. }
  217. TfLiteRegistration Register_MAX_POOL_2D() {
  218. return {/*init=*/pooling::Init,
  219. /*free=*/nullptr,
  220. /*prepare=*/pooling::Prepare,
  221. /*invoke=*/pooling::MaxEval,
  222. /*profiling_string=*/nullptr,
  223. /*builtin_code=*/0,
  224. /*custom_name=*/nullptr,
  225. /*version=*/0};
  226. }
  227. } // namespace micro
  228. } // namespace ops
  229. } // namespace tflite