depthwise_conv.cc

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"

#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace depthwise_conv {
namespace {

constexpr int kInputTensor = 0;
constexpr int kFilterTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;

// Depthwise conv is quantized along dimension 3:
// https://www.tensorflow.org/lite/performance/quantization_spec
constexpr int kDepthwiseConvQuantizedDimension = 3;
struct OpData {
  TfLitePaddingValues padding;

  // Cached tensor zero point values for quantized operations.
  int32_t input_zero_point;
  int32_t filter_zero_point;
  int32_t output_zero_point;

  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;

  // Per channel output multiplier and shift.
  int32_t* per_channel_output_multiplier;
  int32_t* per_channel_output_shift;

  // The range of the fused activation layer. For example for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;
};
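// Validates the input/output counts, computes the padding, and, for quantized
// types, fills in the output multiplier/shift and fused activation range via
// PopulateConvolutionQuantizationParams.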
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                             TfLiteDepthwiseConvParams* params, int width,
                             int height, int filter_width, int filter_height,
                             const TfLiteType data_type, OpData* data) {
  bool has_bias = node->inputs->size == 3;
  // Check number of inputs/outputs
  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  int unused_output_height, unused_output_width;
  data->padding = ComputePaddingHeightWidth(
      params->stride_height, params->stride_width, 1, 1, height, width,
      filter_height, filter_width, params->padding, &unused_output_height,
      &unused_output_width);

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  if (data_type != kTfLiteFloat32) {
    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
    const TfLiteTensor* bias =
        GetOptionalInputTensor(context, node, kBiasTensor);
    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
    int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];

    return tflite::PopulateConvolutionQuantizationParams(
        context, input, filter, bias, output, params->activation,
        &data->output_multiplier, &data->output_shift,
        &data->output_activation_min, &data->output_activation_max,
        data->per_channel_output_multiplier,
        reinterpret_cast<int*>(data->per_channel_output_shift), num_channels);
  }
  return kTfLiteOk;
}

}  // namespace
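// Allocates the per-node OpData struct from the interpreter's persistent
// arena; it stays alive for the lifetime of the model.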
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}
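// Prepare runs once at graph-initialization time: it allocates the
// per-channel multiplier/shift arrays, validates the filter's per-channel
// quantization parameters for int8_t, and caches everything Eval needs in
// OpData.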
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  auto* params =
      reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
  OpData* data = static_cast<OpData*>(node->user_data);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);

  const TfLiteType data_type = input->type;
  int width = SizeOfDimension(input, 2);
  int height = SizeOfDimension(input, 1);
  int filter_width = SizeOfDimension(filter, 2);
  int filter_height = SizeOfDimension(filter, 1);

  // Per channel quantization is only needed for int8_t inference. For other
  // quantized types, only a single scale and zero point is needed.
  const int num_channels =
      filter->dims->data[kDepthwiseConvQuantizedDimension];
  // Dynamically allocate per-channel quantization parameters.
  data->per_channel_output_multiplier =
      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));
  data->per_channel_output_shift =
      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));

  // All per-channel quantized tensors need valid zero point and scale arrays.
  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
                      kTfLiteAffineQuantization);

    const auto* affine_quantization =
        reinterpret_cast<TfLiteAffineQuantization*>(
            filter->quantization.params);
    TF_LITE_ENSURE(context, affine_quantization);
    TF_LITE_ENSURE(context, affine_quantization->scale);
    TF_LITE_ENSURE(context, affine_quantization->zero_point);
    TF_LITE_ENSURE(
        context, affine_quantization->scale->size == 1 ||
                     affine_quantization->scale->size ==
                         filter->dims->data[kDepthwiseConvQuantizedDimension]);
    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                      affine_quantization->zero_point->size);
  }

  TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height,
                                        filter_width, filter_height, data_type,
                                        data));

  data->input_zero_point = input->params.zero_point;
  data->filter_zero_point = filter->params.zero_point;
  data->output_zero_point = output->params.zero_point;

  return kTfLiteOk;
}
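// Float path: the activation range comes straight from the fused activation
// and the reference float kernel does the rest.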
void EvalFloat(TfLiteContext* context, TfLiteNode* node,
               TfLiteDepthwiseConvParams* params, const OpData& data,
               const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
               const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);

  tflite::DepthwiseParams op_params;
  // Padding type is ignored, but still set.
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = data.padding.width;
  op_params.padding_values.height = data.padding.height;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;
  op_params.dilation_width_factor = params->dilation_width_factor;
  op_params.dilation_height_factor = params->dilation_height_factor;
  op_params.depth_multiplier = params->depth_multiplier;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;

  tflite::reference_ops::DepthwiseConv(
      op_params, tflite::micro::GetTensorShape(input),
      tflite::micro::GetTensorData<float>(input),
      tflite::micro::GetTensorShape(filter),
      tflite::micro::GetTensorData<float>(filter),
      tflite::micro::GetTensorShape(bias),
      tflite::micro::GetTensorData<float>(bias),
      tflite::micro::GetTensorShape(output),
      tflite::micro::GetTensorData<float>(output));
}
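// int8_t path with per-channel quantization. The input offset is the negated
// input zero point; the filter is expected to be symmetrically quantized per
// channel (zero point 0), so weights_offset stays 0.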
void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
                             TfLiteDepthwiseConvParams* params,
                             const OpData& data, const TfLiteEvalTensor* input,
                             const TfLiteEvalTensor* filter,
                             const TfLiteEvalTensor* bias,
                             TfLiteEvalTensor* output) {
  DepthwiseParams op_params;
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = data.padding.width;
  op_params.padding_values.height = data.padding.height;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;
  op_params.dilation_width_factor = params->dilation_width_factor;
  op_params.dilation_height_factor = params->dilation_height_factor;
  op_params.depth_multiplier = params->depth_multiplier;
  op_params.input_offset = -data.input_zero_point;
  op_params.weights_offset = 0;
  op_params.output_offset = data.output_zero_point;
  // TODO(b/130439627): Use calculated value for clamping.
  op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
  op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();

  reference_integer_ops::DepthwiseConvPerChannel(
      op_params, data.per_channel_output_multiplier,
      data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
      tflite::micro::GetTensorData<int8_t>(input),
      tflite::micro::GetTensorShape(filter),
      tflite::micro::GetTensorData<int8_t>(filter),
      tflite::micro::GetTensorShape(bias),
      tflite::micro::GetTensorData<int32_t>(bias),
      tflite::micro::GetTensorShape(output),
      tflite::micro::GetTensorData<int8_t>(output));
}
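// uint8_t path: a single scale/zero point per tensor, applied through one
// output_multiplier/output_shift pair instead of per-channel arrays.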
void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                   TfLiteDepthwiseConvParams* params, const OpData& data,
                   const TfLiteEvalTensor* input,
                   const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
                   TfLiteEvalTensor* output) {
  const int32_t input_offset = -data.input_zero_point;
  const int32_t filter_offset = -data.filter_zero_point;
  const int32_t output_offset = data.output_zero_point;

  tflite::DepthwiseParams op_params;
  // Padding type is ignored, but still set.
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = data.padding.width;
  op_params.padding_values.height = data.padding.height;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;
  op_params.dilation_width_factor = params->dilation_width_factor;
  op_params.dilation_height_factor = params->dilation_height_factor;
  op_params.depth_multiplier = params->depth_multiplier;
  op_params.quantized_activation_min = data.output_activation_min;
  op_params.quantized_activation_max = data.output_activation_max;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = data.output_multiplier;
  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
  op_params.output_shift = -data.output_shift;

  tflite::reference_ops::DepthwiseConv(
      op_params, tflite::micro::GetTensorShape(input),
      tflite::micro::GetTensorData<uint8_t>(input),
      tflite::micro::GetTensorShape(filter),
      tflite::micro::GetTensorData<uint8_t>(filter),
      tflite::micro::GetTensorShape(bias),
      tflite::micro::GetTensorData<int32_t>(bias),
      tflite::micro::GetTensorShape(output),
      tflite::micro::GetTensorData<uint8_t>(output));
}
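// Eval dispatches on the input type, using the parameters cached in OpData by
// Prepare.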
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  auto* params =
      reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  const TfLiteEvalTensor* filter =
      tflite::micro::GetEvalInput(context, node, kFilterTensor);
  const TfLiteEvalTensor* bias =
      (NumInputs(node) == 3)
          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
          : nullptr;

  // TODO(aselle): Consider whether float conv and quantized conv should be
  // separate ops to avoid dispatch overhead here.
  switch (input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      EvalFloat(context, node, params, data, input, filter, bias, output);
      break;
    case kTfLiteInt8:
      EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
                              output);
      break;
    case kTfLiteUInt8:
      EvalQuantized(context, node, params, data, input, filter, bias, output);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input->type), input->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace depthwise_conv
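// Typical usage (a sketch, not part of this file): an application pulls this
// kernel in through an op resolver, e.g. roughly
//   tflite::MicroMutableOpResolver<1> resolver;
//   resolver.AddDepthwiseConv2D();
// which ultimately registers Register_DEPTHWISE_CONV_2D() below.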
TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
  return {/*init=*/depthwise_conv::Init,
          /*free=*/nullptr,
          /*prepare=*/depthwise_conv::Prepare,
          /*invoke=*/depthwise_conv::Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite