concatenation.cc

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/concatenation.h"

#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace micro {
namespace concatenation {

constexpr int kMaxInputNum = 10;  // Maximum number of input tensors
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // This function only checks the types. Additional shape validations are
  // performed in the reference implementation called during Eval().
  const TfLiteConcatenationParams* params =
      reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
  TfLiteType input_type = GetInput(context, node, 0)->type;
  TfLiteType output_type = GetOutput(context, node, kOutputTensor)->type;

  // Check activation and input type.
  TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
  TF_LITE_ENSURE(context,
                 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
                     input_type == kTfLiteInt8 || input_type == kTfLiteInt32 ||
                     input_type == kTfLiteInt64);

  // Output type must match input type.
  TF_LITE_ENSURE_EQ(context, output_type, input_type);

  // This implementation does not support a large number of input tensors.
  const int num_inputs = NumInputs(node);
  TF_LITE_ENSURE(context, num_inputs <= kMaxInputNum);

  // Shapes with dimensions > 4 are not yet supported with static allocation.
  for (int i = 0; i < num_inputs; ++i) {
    const TfLiteTensor* input = GetInput(context, node, i);
    int num_dimensions = NumDimensions(input);

    if (num_dimensions > 4) {
      TF_LITE_KERNEL_LOG(
          context,
          "Op Concatenation does not currently support num dimensions > 4. "
          "Tensor '%s' has %d dimensions.",
          input->name, num_dimensions);
      return kTfLiteError;
    }
  }

  return kTfLiteOk;
}
// Handles negative axis index, coerces to positive index value.
inline int CalculatePositiveAxis(int axis, const TfLiteTensor* output_tensor) {
  if (axis >= 0) {
    return axis;
  } else {
    return NumDimensions(output_tensor) + axis;
  }
}
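
// Worked example of the coercion above: for a 4-D output tensor, an axis of
// -1 resolves to NumDimensions(output) + (-1) = 3, i.e. concatenation along
// the innermost dimension.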
// The following functions are helpers to get tensor data in the format that
// the reference op implementation expects. They provide the same functionality
// as class VectorOfTensors and class VectorOfQuantizedTensors in TFLite.

// Gets shapes from a list of tensors.
inline void GetAllTensorShapes(const TfLiteContext& context,
                               const TfLiteIntArray& tensor_list,
                               RuntimeShape all_shapes[kMaxInputNum]) {
  for (int i = 0; i < tensor_list.size; ++i) {
    const TfLiteTensor* t = &context.tensors[tensor_list.data[i]];
    RuntimeShape shape = GetTensorShape(t);
    all_shapes[i].ReplaceWith(shape.DimensionsCount(), shape.DimsData());
  }
}

// Gets shape pointers from a list of shapes.
inline void GetShapesPointers(const RuntimeShape* shapes, size_t num,
                              const RuntimeShape* pointers[]) {
  for (size_t i = 0; i < num; ++i) {
    pointers[i] = &shapes[i];
  }
}

// Gets data pointers from a list of tensors.
template <typename T>
inline void GetAllTensorData(const TfLiteContext& context,
                             const TfLiteIntArray& tensor_list,
                             T* all_data[kMaxInputNum]) {
  for (int i = 0; i < tensor_list.size; ++i) {
    const TfLiteTensor* t = &context.tensors[tensor_list.data[i]];
    all_data[i] = GetTensorData<T>(t);
  }
}

// Gets scale and zero point from a list of tensors.
inline void GetAllQuantizationParam(const TfLiteContext& context,
                                    const TfLiteIntArray& tensor_list,
                                    float scales[kMaxInputNum],
                                    int32_t zero_points[kMaxInputNum]) {
  for (int i = 0; i < tensor_list.size; ++i) {
    const TfLiteTensor* t = &context.tensors[tensor_list.data[i]];
    scales[i] = t->params.scale;
    zero_points[i] = t->params.zero_point;
  }
}
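
// Taken together, these helpers flatten a TfLiteIntArray of tensor indices
// into parallel C arrays (shapes, shape pointers, data pointers, quantization
// params) indexed by input position, which is the layout the reference ops
// expect. See the Eval paths below for the call pattern.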
template <typename data_type>
void EvalUnquantized(TfLiteContext* context, TfLiteNode* node) {
  // Collect the shapes and data pointers of the input tensors.
  RuntimeShape inputs_shape[kMaxInputNum];
  const RuntimeShape* inputs_shape_ptr[kMaxInputNum];
  const data_type* inputs_data[kMaxInputNum];
  GetAllTensorShapes(*context, *node->inputs, inputs_shape);
  GetShapesPointers(inputs_shape, node->inputs->size, inputs_shape_ptr);
  GetAllTensorData(*context, *node->inputs, inputs_data);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  const TfLiteConcatenationParams* params =
      reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);

  ConcatenationParams op_params;
  op_params.axis = CalculatePositiveAxis(params->axis, output);
  op_params.inputs_count = NumInputs(node);

  reference_ops::Concatenation(op_params, inputs_shape_ptr, inputs_data,
                               GetTensorShape(output),
                               GetTensorData<data_type>(output));
}
void EvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node) {
  // Collect the shapes, data pointers, and quantization parameters of the
  // input tensors.
  RuntimeShape inputs_shape[kMaxInputNum];
  const RuntimeShape* inputs_shape_ptr[kMaxInputNum];
  const uint8_t* inputs_data[kMaxInputNum];
  float inputs_scale[kMaxInputNum];
  int32_t inputs_zero_point[kMaxInputNum];
  GetAllTensorShapes(*context, *node->inputs, inputs_shape);
  GetShapesPointers(inputs_shape, node->inputs->size, inputs_shape_ptr);
  GetAllTensorData(*context, *node->inputs, inputs_data);
  GetAllQuantizationParam(*context, *node->inputs, inputs_scale,
                          inputs_zero_point);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  const TfLiteConcatenationParams* params =
      reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);

  ConcatenationParams op_params;
  op_params.axis = CalculatePositiveAxis(params->axis, output);
  op_params.inputs_count = NumInputs(node);
  op_params.input_zeropoint = inputs_zero_point;
  op_params.input_scale = inputs_scale;
  op_params.output_zeropoint = output->params.zero_point;
  op_params.output_scale = output->params.scale;

  reference_ops::ConcatenationWithScaling(op_params, inputs_shape_ptr,
                                          inputs_data, GetTensorShape(output),
                                          GetTensorData<uint8_t>(output));
}
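
// Note on the quantized path: unlike the other types, uint8 inputs may each
// carry a different (scale, zero_point) pair, so every element must be
// requantized into the output's quantization space. Conceptually, for each
// input value q:
//   real  = (q - input_zero_point) * input_scale
//   q_out = round(real / output_scale) + output_zero_point
// which is why ConcatenationWithScaling needs the per-tensor parameters
// gathered above, rather than a plain memcpy-style concatenation.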
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TfLiteType output_type = GetOutput(context, node, kOutputTensor)->type;

  switch (output_type) {  // Input/output types already match; see Prepare().
    case kTfLiteFloat32:
      EvalUnquantized<float>(context, node);
      break;
    case kTfLiteInt32:
      EvalUnquantized<int32_t>(context, node);
      break;
    case kTfLiteUInt8:
      EvalQuantizedUInt8(context, node);
      break;
    case kTfLiteInt8:
      EvalUnquantized<int8_t>(context, node);
      break;
    case kTfLiteInt64:
      EvalUnquantized<int64_t>(context, node);
      break;

    default:
      TF_LITE_KERNEL_LOG(
          context, "Op Concatenation does not currently support Type '%s'.",
          TfLiteTypeGetName(output_type));
      return kTfLiteError;
  }

  return kTfLiteOk;
}
}  // namespace concatenation

TfLiteRegistration* Register_CONCATENATION() {
  static TfLiteRegistration r = {/*init=*/nullptr,
                                 /*free=*/nullptr,
                                 /*prepare=*/concatenation::Prepare,
                                 /*invoke=*/concatenation::Eval,
                                 /*profiling_string=*/nullptr,
                                 /*builtin_code=*/0,
                                 /*custom_name=*/nullptr,
                                 /*version=*/0};
  return &r;
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
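
// Usage sketch (an illustration, not part of the kernel): wiring this op into
// an application via the micro op resolver. This assumes a
// MicroMutableOpResolver with the AddBuiltin(BuiltinOperator,
// TfLiteRegistration*) overload, as in TFLite Micro of this era.
//
//   tflite::MicroMutableOpResolver resolver;
//   resolver.AddBuiltin(tflite::BuiltinOperator_CONCATENATION,
//                       tflite::ops::micro::Register_CONCATENATION());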