pad.cc 9.8 KB


  1. /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  #include "tensorflow/lite/kernels/internal/reference/pad.h"

  #include <string.h>

  #include <limits>

  #include "tensorflow/lite/c/builtin_op_data.h"
  #include "tensorflow/lite/c/common.h"
  #include "tensorflow/lite/kernels/internal/tensor.h"
  #include "tensorflow/lite/kernels/internal/types.h"
  #include "tensorflow/lite/kernels/kernel_util.h"
  #include "tensorflow/lite/kernels/op_macros.h"
  #include "tensorflow/lite/micro/kernels/kernel_util.h"
  21. namespace tflite {
  22. namespace ops {
  23. namespace micro {
  24. namespace pad {
  25. namespace {
  26. struct OpData {
  27. PadParams params;
  28. int32_t output_zero_point;
  29. };
  30. } // namespace
  31. void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  32. TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  33. return context->AllocatePersistentBuffer(context, sizeof(OpData));
  34. }
  35. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  36. TFLITE_DCHECK(node->user_data != nullptr);
  37. OpData* data = static_cast<OpData*>(node->user_data);
  38. TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
  39. TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  40. const TfLiteTensor* input = GetInput(context, node, /*index=*/0);
  41. const TfLiteTensor* paddings = GetInput(context, node, /*index=*/1);
  42. const TfLiteTensor* constant_values =
  43. NumInputs(node) == 3 ? GetInput(context, node, /*index=*/2) : nullptr;
  44. TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
  45. TF_LITE_ENSURE_EQ(context, input->type, output->type);
  46. // Current implementations rely on the inputs being <= 4D.
  47. TF_LITE_ENSURE(context, NumDimensions(input) <=
  48. reference_ops::PadKernelMaxDimensionCount());
  49. if (constant_values != nullptr) {
  50. TF_LITE_ENSURE_EQ(context, input->type, constant_values->type);
  51. // Ensure that constant_values is a scalar.
  52. TF_LITE_ENSURE_EQ(context, NumElements(constant_values), 1);
  53. }
  54. // There must be a pair of paddings for each output dimension.
  55. TF_LITE_ENSURE_EQ(context, GetTensorShape(paddings).FlatSize(),
  56. output->dims->size * 2);
  57. // On Micro, outputs must be properly sized by the converter.
  58. // NOTE: This data is only available because the paddings buffer is stored in
  59. // the flatbuffer:
  60. TF_LITE_ENSURE(context, IsConstantTensor(paddings));
  61. const int32_t* paddings_data = GetTensorData<int32_t>(paddings);
  62. for (int i = 0; i < output->dims->size; i++) {
  63. int output_dim = output->dims->data[i];
  64. int expected_dim =
  65. input->dims->data[i] + paddings_data[i * 2] + paddings_data[i * 2 + 1];
  66. TF_LITE_ENSURE_EQ(context, output_dim, expected_dim);
  67. }
  68. // Calculate OpData:
  69. data->params.resizing_category = ResizingCategory::kGenericResize;
  70. const int paddings_total = GetTensorShape(paddings).FlatSize();
  71. if (paddings_total == 8 && (paddings_data[0] == 0 && paddings_data[1] == 0) &&
  72. (paddings_data[6] == 0 && paddings_data[7] == 0)) {
  73. data->params.resizing_category = ResizingCategory::kImageStyle;
  74. }
  75. const int num_input_dimensions = NumDimensions(input);
  76. data->params.left_padding_count = num_input_dimensions;
  77. data->params.right_padding_count = num_input_dimensions;
  78. for (int idx = num_input_dimensions - 1; idx >= 0; --idx) {
  79. data->params.left_padding[idx] = paddings_data[idx * 2];
  80. data->params.right_padding[idx] = paddings_data[idx * 2 + 1];
  81. }
  82. if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
  83. if (constant_values == nullptr) {
  84. // Quantized Pad requires that 0 is represented in the quantized
  85. // range.
  86. if (input->type == kTfLiteUInt8) {
  87. TF_LITE_ENSURE(context, output->params.zero_point >=
  88. std::numeric_limits<uint8_t>::min());
  89. TF_LITE_ENSURE(context, output->params.zero_point <=
  90. std::numeric_limits<uint8_t>::max());
  91. } else {
  92. TF_LITE_ENSURE(context, output->params.zero_point >=
  93. std::numeric_limits<int8_t>::min());
  94. TF_LITE_ENSURE(context, output->params.zero_point <=
  95. std::numeric_limits<int8_t>::max());
  96. }
  97. } else {
  98. // Quantized Pad requires that 'constant_values' is represented in the
  99. // same quantized range as the input and output tensors.
  100. TF_LITE_ENSURE_EQ(context, output->params.zero_point,
  101. constant_values->params.zero_point);
  102. TF_LITE_ENSURE_EQ(context, static_cast<double>(output->params.scale),
  103. static_cast<double>(constant_values->params.scale));
  104. }
  105. data->output_zero_point = output->params.zero_point;
  106. }
  107. return kTfLiteOk;
  108. }
  109. TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  110. TFLITE_DCHECK(node->user_data != nullptr);
  111. const OpData* data = static_cast<const OpData*>(node->user_data);
  112. const TfLiteEvalTensor* input =
  113. tflite::micro::GetEvalInput(context, node, /*index=*/0);
  114. const TfLiteEvalTensor* constant_values =
  115. NumInputs(node) == 3
  116. ? tflite::micro::GetEvalInput(context, node, /*index=*/2)
  117. : nullptr;
  118. TfLiteEvalTensor* output =
  119. tflite::micro::GetEvalOutput(context, node, /*index=*/0);
  120. switch (input->type) {
  121. case kTfLiteFloat32: {
  122. float pad_value =
  123. constant_values == nullptr
  124. ? 0.f
  125. : *tflite::micro::GetTensorData<float>(constant_values);
  126. if (data->params.resizing_category == ResizingCategory::kImageStyle) {
  127. reference_ops::PadImageStyle(
  128. data->params, tflite::micro::GetTensorShape(input),
  129. tflite::micro::GetTensorData<float>(input), &pad_value,
  130. tflite::micro::GetTensorShape(output),
  131. tflite::micro::GetTensorData<float>(output));
  132. } else {
  133. reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
  134. tflite::micro::GetTensorData<float>(input),
  135. &pad_value, tflite::micro::GetTensorShape(output),
  136. tflite::micro::GetTensorData<float>(output));
  137. }
  138. } break;
  139. case kTfLiteUInt8: {
  140. uint8_t pad_value;
  141. if (constant_values == nullptr) {
  142. pad_value = static_cast<uint8_t>(data->output_zero_point);
  143. } else {
  144. pad_value = *tflite::micro::GetTensorData<uint8_t>(constant_values);
  145. }
  146. if (data->params.resizing_category == ResizingCategory::kImageStyle) {
  147. reference_ops::PadImageStyle(
  148. data->params, tflite::micro::GetTensorShape(input),
  149. tflite::micro::GetTensorData<uint8_t>(input), &pad_value,
  150. tflite::micro::GetTensorShape(output),
  151. tflite::micro::GetTensorData<uint8_t>(output));
  152. } else {
  153. reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
  154. tflite::micro::GetTensorData<uint8_t>(input),
  155. &pad_value, tflite::micro::GetTensorShape(output),
  156. tflite::micro::GetTensorData<uint8_t>(output));
  157. }
  158. } break;
  159. case kTfLiteInt8: {
  160. int8_t pad_value;
  161. if (constant_values == nullptr) {
  162. pad_value = static_cast<uint8_t>(data->output_zero_point);
  163. } else {
  164. pad_value = *tflite::micro::GetTensorData<int8_t>(constant_values);
  165. }
  166. if (data->params.resizing_category == ResizingCategory::kImageStyle) {
  167. reference_ops::PadImageStyle(
  168. data->params, tflite::micro::GetTensorShape(input),
  169. tflite::micro::GetTensorData<int8_t>(input), &pad_value,
  170. tflite::micro::GetTensorShape(output),
  171. tflite::micro::GetTensorData<int8_t>(output));
  172. } else {
  173. reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
  174. tflite::micro::GetTensorData<int8_t>(input),
  175. &pad_value, tflite::micro::GetTensorShape(output),
  176. tflite::micro::GetTensorData<int8_t>(output));
  177. }
  178. } break;
  179. case kTfLiteInt32: {
  180. int32_t pad_value =
  181. constant_values == nullptr
  182. ? 0
  183. : *tflite::micro::GetTensorData<int32_t>(constant_values);
  184. reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
  185. tflite::micro::GetTensorData<int32_t>(input),
  186. &pad_value, tflite::micro::GetTensorShape(output),
  187. tflite::micro::GetTensorData<int32_t>(output));
  188. } break;
  189. default:
  190. TF_LITE_KERNEL_LOG(context, "Type %s not currently supported by Pad.",
  191. TfLiteTypeGetName(input->type));
  192. return kTfLiteError;
  193. }
  194. #undef TF_LITE_PAD
  195. return kTfLiteOk;
  196. }
  197. } // namespace pad
  198. TfLiteRegistration Register_PAD() {
  199. return {/*init=*/pad::Init,
  200. /*free=*/nullptr,
  201. /*prepare=*/pad::Prepare,
  202. /*invoke=*/pad::Eval,
  203. /*profiling_string=*/nullptr,
  204. /*builtin_code=*/0,
  205. /*custom_name=*/nullptr,
  206. /*version=*/0};
  207. }
  208. // Also register Pad as PadV2.
  209. TfLiteRegistration Register_PADV2() {
  210. return {/*init=*/pad::Init,
  211. /*free=*/nullptr,
  212. /*prepare=*/pad::Prepare,
  213. /*invoke=*/pad::Eval,
  214. /*profiling_string=*/nullptr,
  215. /*builtin_code=*/0,
  216. /*custom_name=*/nullptr,
  217. /*version=*/0};
  218. }
  219. } // namespace micro
  220. } // namespace ops
  221. } // namespace tflite