/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/conv.h"

#include "cmsis/CMSIS/NN/Include/arm_nn_types.h"
#include "cmsis/CMSIS/NN/Include/arm_nnfunctions.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace micro {
namespace conv {

constexpr int kInputTensor = 0;
constexpr int kFilterTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
constexpr int kMaxChannels = 256;

// Conv is quantized along dimension 0:
// https://www.tensorflow.org/lite/performance/quantization_spec
constexpr int kConvQuantizedDimension = 0;
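// Per-op state, allocated once in Init() and populated in Prepare(), so that
// no quantization parameters need to be recomputed on the invoke path.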
struct OpData {
  TfLitePaddingValues padding;

  // Cached tensor zero point values for quantized operations.
  int32_t input_zero_point;
  int32_t filter_zero_point;
  int32_t output_zero_point;

  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;

  // Per channel output multiplier and shift.
  // TODO(b/141139247): Allocate these dynamically when possible.
  int32_t per_channel_output_multiplier[kMaxChannels];
  int32_t per_channel_output_shift[kMaxChannels];

  // The range of the fused activation layer. For example for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;

  // Index to buffer for optimizations if applicable.
  int buffer_idx;
};
inline PaddingType RuntimePaddingType(TfLitePadding padding) {
  switch (padding) {
    case TfLitePadding::kTfLitePaddingSame:
      return PaddingType::kSame;
    case TfLitePadding::kTfLitePaddingValid:
      return PaddingType::kValid;
    case TfLitePadding::kTfLitePaddingUnknown:
    default:
      return PaddingType::kNone;
  }
}
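
// Validates the node's inputs/outputs, computes the padding, and, for
// quantized types, derives the output multipliers/shifts and the fused
// activation range from the tensors' quantization parameters.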
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                             TfLiteConvParams* params, int width, int height,
                             int filter_width, int filter_height, int out_width,
                             int out_height, const TfLiteType data_type,
                             OpData* data) {
  bool has_bias = node->inputs->size == 3;
  // Check number of inputs/outputs
  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  // Matching GetWindowedOutputSize in TensorFlow.
  auto padding = params->padding;
  data->padding = ComputePaddingHeightWidth(
      params->stride_height, params->stride_width,
      params->dilation_height_factor, params->dilation_width_factor, height,
      width, filter_height, filter_width, padding, &out_height, &out_width);

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  if (data_type != kTfLiteFloat32) {
    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
    const TfLiteTensor* bias =
        GetOptionalInputTensor(context, node, kBiasTensor);
    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
    int num_channels = filter->dims->data[kConvQuantizedDimension];

    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
        context, input, filter, bias, output, params->activation,
        &data->output_multiplier, &data->output_shift,
        &data->output_activation_min, &data->output_activation_max,
        data->per_channel_output_multiplier,
        reinterpret_cast<int*>(data->per_channel_output_shift), num_channels));
  }
  return kTfLiteOk;
}
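
// Allocates the per-op OpData struct from the interpreter's persistent arena.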
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}
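
// On cores with the DSP or MVE extensions, Prepare() fills in OpData, asks
// CMSIS-NN how much scratch memory the optimized s8 kernel needs, and
// requests that buffer from the arena. On other targets it is a no-op and
// the reference kernels are used instead.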
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
#if defined(__ARM_FEATURE_DSP) || defined(__ARM_FEATURE_MVE)
  int32_t buf_size = 0;

  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
  auto* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
  const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  RuntimeShape input_shape = GetTensorShape(input);
  RuntimeShape output_shape = GetTensorShape(output);

  // Initialize cmsis-nn input dimensions
  cmsis_nn_dims input_dims;
  input_dims.n = MatchingDim(input_shape, 0, output_shape, 0);
  input_dims.h = input->dims->data[1];
  input_dims.w = input->dims->data[2];
  input_dims.c = input_shape.Dims(3);

  // Initialize cmsis-nn filter dimensions
  cmsis_nn_dims filter_dims;
  filter_dims.n = output_shape.Dims(3);
  filter_dims.h = filter->dims->data[1];
  filter_dims.w = filter->dims->data[2];
  filter_dims.c = input_dims.c;

  // Initialize cmsis-nn output dimensions
  cmsis_nn_dims output_dims;
  output_dims.n = input_dims.n;
  output_dims.h = output->dims->data[1];
  output_dims.w = output->dims->data[2];
  output_dims.c = output_shape.Dims(3);

  TF_LITE_ENSURE_STATUS(CalculateOpData(
      context, node, params, input_dims.w, input_dims.h, filter_dims.w,
      filter_dims.h, output_dims.w, output_dims.h, input->type, data));

  data->input_zero_point = input->params.zero_point;
  data->filter_zero_point = filter->params.zero_point;
  data->output_zero_point = output->params.zero_point;

  if (input->type == kTfLiteInt8) {
    // Initialize cmsis-nn convolution parameters
    cmsis_nn_conv_params conv_params;
    conv_params.input_offset = -input->params.zero_point;
    conv_params.output_offset = output->params.zero_point;
    conv_params.stride.h = params->stride_height;
    conv_params.stride.w = params->stride_width;
    conv_params.dilation.h = params->dilation_height_factor;
    conv_params.dilation.w = params->dilation_width_factor;
    conv_params.padding.h = data->padding.height;
    conv_params.padding.w = data->padding.width;
    conv_params.activation.min = data->output_activation_min;
    conv_params.activation.max = data->output_activation_max;

    buf_size = arm_convolve_wrapper_s8_get_buffer_size(
        &conv_params, &input_dims, &filter_dims, &output_dims);
  }

  if (buf_size > 0) {
    TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
        context, buf_size, &data->buffer_idx));
  } else {
    data->buffer_idx = -1;
  }
#endif
  return kTfLiteOk;
}
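
// uint8 path. CMSIS-NN only ships optimized s8 kernels, so this always uses
// the portable reference implementation.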
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           TfLiteConvParams* params, const OpData& data,
                           const TfLiteEvalTensor* input,
                           const TfLiteEvalTensor* filter,
                           const TfLiteEvalTensor* bias,
                           TfLiteEvalTensor* im2col,
                           TfLiteEvalTensor* hwcn_weights,
                           TfLiteEvalTensor* output) {
  const int32_t input_offset = -data.input_zero_point;
  const int32_t filter_offset = -data.filter_zero_point;
  const int32_t output_offset = data.output_zero_point;

  ConvParams op_params;
  op_params.padding_type = RuntimePaddingType(params->padding);
  op_params.padding_values.width = data.padding.width;
  op_params.padding_values.height = data.padding.height;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;
  op_params.dilation_width_factor = params->dilation_width_factor;
  op_params.dilation_height_factor = params->dilation_height_factor;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = data.output_multiplier;
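  // The legacy uint8 kernel uses the opposite sign convention for the output
  // shift, hence the negation.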
  op_params.output_shift = -data.output_shift;
  op_params.quantized_activation_min = data.output_activation_min;
  op_params.quantized_activation_max = data.output_activation_max;

  reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
                      tflite::micro::GetTensorData<uint8_t>(input),
                      tflite::micro::GetTensorShape(filter),
                      tflite::micro::GetTensorData<uint8_t>(filter),
                      tflite::micro::GetTensorShape(bias),
                      tflite::micro::GetTensorData<int32_t>(bias),
                      tflite::micro::GetTensorShape(output),
                      tflite::micro::GetTensorData<uint8_t>(output),
                      tflite::micro::GetTensorShape(im2col),
                      tflite::micro::GetTensorData<uint8_t>(im2col), nullptr);
  return kTfLiteOk;
}
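
// int8 per-channel path. On DSP/MVE targets this dispatches to the optimized
// CMSIS-NN kernel via arm_convolve_wrapper_s8; elsewhere it falls back to the
// reference per-channel kernel.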
TfLiteStatus EvalQuantizedPerChannel(
    TfLiteContext* context, TfLiteNode* node, TfLiteConvParams* params,
    const OpData& data, const TfLiteEvalTensor* input,
    const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
    TfLiteEvalTensor* output, TfLiteEvalTensor* im2col) {
  // Initialize cmsis-nn convolution parameters
  cmsis_nn_conv_params conv_params;
  conv_params.input_offset = -data.input_zero_point;
  conv_params.output_offset = data.output_zero_point;
  conv_params.stride.h = params->stride_height;
  conv_params.stride.w = params->stride_width;
  conv_params.dilation.h = params->dilation_height_factor;
  conv_params.dilation.w = params->dilation_width_factor;
  conv_params.padding.h = data.padding.height;
  conv_params.padding.w = data.padding.width;
  conv_params.activation.min = data.output_activation_min;
  conv_params.activation.max = data.output_activation_max;

  // Initialize cmsis-nn per channel quantization parameters
  cmsis_nn_per_channel_quant_params quant_params;
  quant_params.multiplier =
      const_cast<int32_t*>(data.per_channel_output_multiplier);
  quant_params.shift = const_cast<int32_t*>(data.per_channel_output_shift);
#if defined(__ARM_FEATURE_DSP) || defined(__ARM_FEATURE_MVE)
  RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter);
  RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  RuntimeShape bias_shape = tflite::micro::GetTensorShape(bias);

  // Consistency check.
  TFLITE_DCHECK_LE(conv_params.activation.min, conv_params.activation.max);
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
  const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
  if (tflite::micro::GetTensorData<int8_t>(bias)) {
    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
  }

  // Initialize cmsis-nn dimensions
  // Input
  cmsis_nn_dims input_dims;
  input_dims.n = batch_size;
  input_dims.h = input_shape.Dims(1);
  input_dims.w = input_shape.Dims(2);
  input_dims.c = input_depth;

  // Filter
  cmsis_nn_dims filter_dims;
  filter_dims.n = output_depth;
  filter_dims.h = filter_shape.Dims(1);
  filter_dims.w = filter_shape.Dims(2);
  filter_dims.c = input_depth;

  // Bias
  cmsis_nn_dims bias_dims;
  bias_dims.n = 1;
  bias_dims.h = 1;
  bias_dims.w = 1;
  bias_dims.c = output_depth;

  // Output
  cmsis_nn_dims output_dims;
  output_dims.n = batch_size;
  output_dims.h = output_shape.Dims(1);
  output_dims.w = output_shape.Dims(2);
  output_dims.c = output_depth;

  // Initialize cmsis-nn context
  cmsis_nn_context ctx;
  ctx.buf = nullptr;
  ctx.size = 0;

  if (data.buffer_idx > -1) {
    ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
    // Note: ctx.size is currently not used in cmsis-nn.
    // The buffer should be allocated in the Prepare function through
    // arm_convolve_wrapper_s8_get_buffer_size
  }
  // arm_convolve_wrapper_s8 dispatches to the optimized kernel variant that
  // matches the given convolution parameters.
  arm_status status = arm_convolve_wrapper_s8(
      &ctx, &conv_params, &quant_params, &input_dims,
      tflite::micro::GetTensorData<int8_t>(input), &filter_dims,
      tflite::micro::GetTensorData<int8_t>(filter), &bias_dims,
      tflite::micro::GetTensorData<int32_t>(bias), &output_dims,
      tflite::micro::GetTensorData<int8_t>(output));

  if (status == ARM_MATH_SUCCESS) {
    return kTfLiteOk;
  } else {
    return kTfLiteError;
  }
#else
#pragma message( \
    "CMSIS-NN optimization for conv not available for this target. Using reference kernel.")
  ConvParams op_params;
  op_params.input_offset = -data.input_zero_point;
  op_params.output_offset = data.output_zero_point;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.dilation_height_factor = params->dilation_height_factor;
  op_params.dilation_width_factor = params->dilation_width_factor;
  op_params.padding_values.height = data.padding.height;
  op_params.padding_values.width = data.padding.width;
  op_params.quantized_activation_min = data.output_activation_min;
  op_params.quantized_activation_max = data.output_activation_max;

  reference_integer_ops::ConvPerChannel(
      op_params, data.per_channel_output_multiplier,
      data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
      tflite::micro::GetTensorData<int8_t>(input),
      tflite::micro::GetTensorShape(filter),
      tflite::micro::GetTensorData<int8_t>(filter),
      tflite::micro::GetTensorShape(bias),
      tflite::micro::GetTensorData<int32_t>(bias),
      tflite::micro::GetTensorShape(output),
      tflite::micro::GetTensorData<int8_t>(output));
#endif
  return kTfLiteOk;
}
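
// float32 path, using the portable reference convolution.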
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
                       TfLiteConvParams* params, const OpData& data,
                       const TfLiteEvalTensor* input,
                       const TfLiteEvalTensor* filter,
                       const TfLiteEvalTensor* bias, TfLiteEvalTensor* im2col,
                       TfLiteEvalTensor* hwcn_weights,
                       TfLiteEvalTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);

  // TODO(b/154032858): Investigate removing extra copies.
  ConvParams op_params;
  op_params.padding_type = RuntimePaddingType(params->padding);
  op_params.padding_values.width = data.padding.width;
  op_params.padding_values.height = data.padding.height;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;
  op_params.dilation_width_factor = params->dilation_width_factor;
  op_params.dilation_height_factor = params->dilation_height_factor;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;

  reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
                      tflite::micro::GetTensorData<float>(input),
                      tflite::micro::GetTensorShape(filter),
                      tflite::micro::GetTensorData<float>(filter),
                      tflite::micro::GetTensorShape(bias),
                      tflite::micro::GetTensorData<float>(bias),
                      tflite::micro::GetTensorShape(output),
                      tflite::micro::GetTensorData<float>(output),
                      tflite::micro::GetTensorShape(im2col),
                      tflite::micro::GetTensorData<float>(im2col));
  return kTfLiteOk;
}
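
// Per-invocation entry point: fetches the eval tensors and dispatches on the
// input type (float32, int8 per-channel, or legacy uint8).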
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);

  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  const TfLiteEvalTensor* filter =
      tflite::micro::GetEvalInput(context, node, kFilterTensor);
  const TfLiteEvalTensor* bias =
      (NumInputs(node) == 3)
          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
          : nullptr;
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  switch (input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      EvalFloat(context, node, params, data, input, filter, bias, nullptr,
                nullptr, output);
      break;
    case kTfLiteInt8:
      return EvalQuantizedPerChannel(context, node, params, data, input,
                                     filter, bias, output, nullptr);
    case kTfLiteUInt8:
      return EvalQuantized(context, node, params, data, input, filter, bias,
                           nullptr, nullptr, output);
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input->type), input->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace conv
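
// Registration hook used by the micro op resolver to wire up this kernel.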
TfLiteRegistration Register_CONV_2D() {
  return {/*init=*/conv::Init,
          /*free=*/nullptr,
          /*prepare=*/conv::Prepare,
          /*invoke=*/conv::Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite