concatenation.cc

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/concatenation.h"

#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace micro {
namespace concatenation {

constexpr int kMaxInputNum = 10;  // Maximum number of input tensors
constexpr int kOutputTensor = 0;

struct OpData {
  ConcatenationParams params;
};
// Handles negative axis index, coerces to positive index value.
inline int CalculatePositiveAxis(int axis, const TfLiteTensor* output_tensor) {
  if (axis >= 0) {
    return axis;
  } else {
    return NumDimensions(output_tensor) + axis;
  }
}
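// Example: for a 4-D output tensor, axis = -1 resolves to 4 + (-1) = 3 (the
// innermost dimension), while a non-negative axis such as 2 is returned
// unchanged.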
// The following functions are helpers to get tensor data in the format that
// the reference op implementation expects. They provide the same functionality
// as class VectorOfTensors and class VectorOfQuantizedTensors in TFLite.

// Gets shapes from a list of tensors.
inline void GetAllInputTensorShapes(const TfLiteContext* context,
                                    const TfLiteNode* node,
                                    RuntimeShape all_shapes[kMaxInputNum]) {
  TFLITE_DCHECK(context != nullptr);
  TFLITE_DCHECK(node != nullptr);
  for (int i = 0; i < node->inputs->size; ++i) {
    const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
    RuntimeShape shape = tflite::micro::GetTensorShape(t);
    all_shapes[i].ReplaceWith(shape.DimensionsCount(), shape.DimsData());
  }
}
// Gets shape pointers from a list of shapes.
inline void GetShapesPointers(const RuntimeShape* shapes, size_t num,
                              const RuntimeShape* pointers[]) {
  for (size_t i = 0; i < num; ++i) {
    pointers[i] = &shapes[i];
  }
}
// Gets data pointers from a list of tensors.
template <typename T>
inline void GetAllInputTensorData(const TfLiteContext* context,
                                  const TfLiteNode* node,
                                  T* all_data[kMaxInputNum]) {
  TFLITE_DCHECK(context != nullptr);
  TFLITE_DCHECK(node != nullptr);
  for (int i = 0; i < node->inputs->size; ++i) {
    const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
    all_data[i] = tflite::micro::GetTensorData<T>(t);
  }
}
template <typename data_type>
void EvalUnquantized(TfLiteContext* context, TfLiteNode* node) {
  // Collect the shapes and data pointers of the input tensors.
  RuntimeShape inputs_shape[kMaxInputNum];
  const RuntimeShape* inputs_shape_ptr[kMaxInputNum];
  const data_type* inputs_data[kMaxInputNum];
  GetAllInputTensorShapes(context, node, inputs_shape);
  GetShapesPointers(inputs_shape, node->inputs->size, inputs_shape_ptr);
  GetAllInputTensorData(context, node, inputs_data);

  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  reference_ops::Concatenation(data->params, inputs_shape_ptr, inputs_data,
                               tflite::micro::GetTensorShape(output),
                               tflite::micro::GetTensorData<data_type>(output));
}
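// Example: concatenating two float inputs of shapes [2, 2] and [2, 3] along
// axis 1 yields a [2, 5] output; every dimension other than the concatenation
// axis must match across inputs.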
void EvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node) {
  // Collect the shapes and data pointers of the input tensors.
  RuntimeShape inputs_shape[kMaxInputNum];
  const RuntimeShape* inputs_shape_ptr[kMaxInputNum];
  const uint8_t* inputs_data[kMaxInputNum];
  GetAllInputTensorShapes(context, node, inputs_shape);
  GetShapesPointers(inputs_shape, node->inputs->size, inputs_shape_ptr);
  GetAllInputTensorData(context, node, inputs_data);

  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  reference_ops::ConcatenationWithScaling(
      data->params, inputs_shape_ptr, inputs_data,
      tflite::micro::GetTensorShape(output),
      tflite::micro::GetTensorData<uint8_t>(output));
}
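// Unlike the unquantized path, ConcatenationWithScaling requantizes each
// element from its input's (scale, zero_point) to the output's, roughly:
//   out = output_zero_point +
//         round((in - input_zero_point) * input_scale / output_scale)
// with the result clamped to the uint8 range [0, 255].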
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}
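// The buffer returned here becomes node->user_data and persists for the
// lifetime of the interpreter, so Prepare() can cache ConcatenationParams in
// it and Eval() can read them back without recomputing anything.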
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // This function only checks the types. Additional shape validations are
  // performed in the reference implementation called during Eval().
  const TfLiteConcatenationParams* params =
      reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
  TfLiteType input_type = GetInput(context, node, 0)->type;
  TfLiteType output_type = GetOutput(context, node, kOutputTensor)->type;

  // Check activation and input type.
  TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
  TF_LITE_ENSURE(context,
                 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
                     input_type == kTfLiteInt8 || input_type == kTfLiteInt32 ||
                     input_type == kTfLiteInt64);

  // Output type must match input type.
  TF_LITE_ENSURE_EQ(context, output_type, input_type);

  // This implementation does not support a large number of input tensors.
  const int num_inputs = NumInputs(node);
  TF_LITE_ENSURE(context, num_inputs <= kMaxInputNum);

  // Shapes with dimensions > 4 are not yet supported with static allocation.
  for (int i = 0; i < num_inputs; ++i) {
    const TfLiteTensor* input = GetInput(context, node, i);
    int num_dimensions = NumDimensions(input);

    if (num_dimensions > 4) {
      TF_LITE_KERNEL_LOG(
          context,
          "Op Concatenation does not currently support num dimensions > 4. "
          "Tensor has %d dimensions.",
          num_dimensions);
      return kTfLiteError;
    }
  }

  // Calculate OpData.
  TFLITE_DCHECK(node->user_data != nullptr);
  OpData* data = static_cast<OpData*>(node->user_data);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  switch (output_type) {  // Already know in/out types are the same.
    case kTfLiteFloat32:
    case kTfLiteInt32:
    case kTfLiteInt64: {
      data->params.axis = CalculatePositiveAxis(params->axis, output);
      data->params.inputs_count = node->inputs->size;
      break;
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      data->params.axis = CalculatePositiveAxis(params->axis, output);
      data->params.inputs_count = node->inputs->size;

      // Allocate persistent scale and zero-point buffers, then store each
      // input's quantization parameters in OpParams:
      float* input_scales =
          reinterpret_cast<float*>(context->AllocatePersistentBuffer(
              context, node->inputs->size * sizeof(float)));
      int32_t* input_zero_points =
          reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
              context, node->inputs->size * sizeof(int32_t)));

      for (int i = 0; i < node->inputs->size; ++i) {
        const TfLiteTensor* t = GetInput(context, node, i);
        input_scales[i] = t->params.scale;
        input_zero_points[i] = t->params.zero_point;
      }

      data->params.input_scale = input_scales;
      data->params.input_zeropoint = input_zero_points;
      data->params.output_zeropoint = output->params.zero_point;
      data->params.output_scale = output->params.scale;
      break;
    }
    default:
      TF_LITE_KERNEL_LOG(
          context, "Op Concatenation does not currently support Type '%s'.",
          TfLiteTypeGetName(output_type));
      return kTfLiteError;
  }

  return kTfLiteOk;
}
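// Example: with two uint8 inputs quantized as (scale = 0.5, zero_point = 128)
// and (scale = 0.25, zero_point = 0), Prepare() stores input_scale =
// {0.5f, 0.25f} and input_zeropoint = {128, 0} alongside the output's scale
// and zero point, which is everything ConcatenationWithScaling needs at
// Eval() time.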
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TfLiteType output_type = GetOutput(context, node, kOutputTensor)->type;

  switch (output_type) {  // Already know in/out types are the same.
    case kTfLiteFloat32:
      EvalUnquantized<float>(context, node);
      break;
    case kTfLiteInt32:
      EvalUnquantized<int32_t>(context, node);
      break;
    case kTfLiteUInt8:
      EvalQuantizedUInt8(context, node);
      break;
    case kTfLiteInt8:
      EvalUnquantized<int8_t>(context, node);
      break;
    case kTfLiteInt64:
      EvalUnquantized<int64_t>(context, node);
      break;

    default:
      TF_LITE_KERNEL_LOG(
          context, "Op Concatenation does not currently support Type '%s'.",
          TfLiteTypeGetName(output_type));
      return kTfLiteError;
  }

  return kTfLiteOk;
}
}  // namespace concatenation

TfLiteRegistration Register_CONCATENATION() {
  return {/*init=*/concatenation::Init,
          /*free=*/nullptr,
          /*prepare=*/concatenation::Prepare,
          /*invoke=*/concatenation::Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
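
// Usage sketch (illustrative; op-resolver APIs differ across TFLM releases):
// the kernel is registered with an op resolver before the interpreter is
// built, along the lines of
//
//   tflite::MicroMutableOpResolver resolver;
//   resolver.AddBuiltin(tflite::BuiltinOperator_CONCATENATION,
//                       tflite::ops::micro::Register_CONCATENATION());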