pad.cc

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/pad.h"

#include <string.h>

#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

#ifdef MEMORY_SANITIZER
#include <sanitizer/msan_interface.h>
#else
#define __msan_check_mem_is_initialized(ptr, size)
#endif
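
// When not building under MemorySanitizer, the block above defines
// __msan_check_mem_is_initialized as a no-op macro so that any call sites
// compile away entirely.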

namespace tflite {
namespace ops {
namespace micro {
namespace pad {

struct PadContext {
  PadContext(TfLiteContext* context, TfLiteNode* node) {
    input = GetInput(context, node, 0);
    paddings = GetInput(context, node, 1);
    constant_values = nullptr;
    if (NumInputs(node) == 3) {
      constant_values = GetOptionalInputTensor(context, node, 2);
    }
    output = GetOutput(context, node, 0);
    dims = NumDimensions(input);

    resizing_category = ResizingCategory::kGenericResize;
    const int paddings_total = GetTensorShape(paddings).FlatSize();
    const int32* paddings_data = GetTensorData<int32>(paddings);
    // Paddings will be an n-by-2 array, and we need to detect 4D arrays with
    // the pattern { {0,0}, {a, b}, {c, d}, {0,0} }.
    if (IsConstantTensor(paddings) && paddings_total == 8 &&
        (paddings_data[0] == 0 && paddings_data[1] == 0) &&
        (paddings_data[6] == 0 && paddings_data[7] == 0)) {
      resizing_category = ResizingCategory::kImageStyle;
    }
  }
  const TfLiteTensor* constant_values;
  const TfLiteTensor* input;
  const TfLiteTensor* paddings;
  TfLiteTensor* output;
  int dims;
  ResizingCategory resizing_category;
};
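
// Illustrative example (not from the original source): a 1x8x8x3 NHWC input
// padded by one pixel on each side of height and width uses paddings
// {{0, 0}, {1, 1}, {1, 1}, {0, 0}}, which flattens to {0, 0, 1, 1, 1, 1, 0, 0}
// and, provided the paddings tensor is constant, matches the kImageStyle
// pattern detected above.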

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  PadContext op_context(context, node);
  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
  if (op_context.constant_values != nullptr) {
    TF_LITE_ENSURE_EQ(context, op_context.input->type,
                      op_context.constant_values->type);
  }

  // There must be a pair of paddings for each output dimension.
  TF_LITE_ENSURE_EQ(context, GetTensorShape(op_context.paddings).FlatSize(),
                    op_context.output->dims->size * 2);

  // On Micro, outputs must be properly sized by the converter.
  const int32* paddings_data = GetTensorData<int32>(op_context.paddings);
  for (int i = 0; i < op_context.output->dims->size; i++) {
    int output_dim = op_context.output->dims->data[i];
    int expected_dim = op_context.input->dims->data[i] + paddings_data[i * 2] +
                       paddings_data[i * 2 + 1];
    TF_LITE_ENSURE_EQ(context, output_dim, expected_dim);
  }
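  // Worked example (illustrative): for a 2x3 input with paddings
  // {{1, 1}, {2, 2}}, the converter must have sized the output as
  // {2 + 1 + 1, 3 + 2 + 2} = {4, 7}; anything else fails the check above.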

  // Current implementations rely on the inputs being <= 4D.
  TF_LITE_ENSURE(
      context, op_context.dims <= reference_ops::PadKernelMaxDimensionCount());
  TF_LITE_ENSURE(context, IsConstantTensor(op_context.paddings));
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  PadContext op_context(context, node);
  if (op_context.constant_values != nullptr) {
    // Ensure that constant_values is a scalar.
    TF_LITE_ENSURE_EQ(context, NumElements(op_context.constant_values), 1);
  }

  // Create before and after padding arrays that are accepted by the kernel.
  const int32* paddings_data = GetTensorData<int32>(op_context.paddings);

  tflite::PadParams op_params;
  memset(&op_params, 0, sizeof(PadParams));
  op_params.left_padding_count = op_context.dims;
  op_params.right_padding_count = op_context.dims;

  for (int idx = op_context.dims - 1; idx >= 0; --idx) {
    op_params.left_padding[idx] = paddings_data[idx * 2];
    op_params.right_padding[idx] = paddings_data[idx * 2 + 1];
  }
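  // The flattened paddings interleave (before, after) pairs per dimension.
  // Illustrative example: with dims == 4, the data {0, 0, 1, 1, 1, 1, 0, 0}
  // yields left_padding {0, 1, 1, 0} and right_padding {0, 1, 1, 0}.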

#define TF_LITE_PAD(type, op_name, scalar, pad_value)                     \
  const scalar pad_value_copy = pad_value;                                \
                                                                          \
  type::op_name(op_params, GetTensorShape(op_context.input),              \
                GetTensorData<scalar>(op_context.input), &pad_value_copy, \
                GetTensorShape(op_context.output),                        \
                GetTensorData<scalar>(op_context.output))
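
// For instance, TF_LITE_PAD(reference_ops, Pad, float, pad_value) expands to
// a local copy `const float pad_value_copy = pad_value;` followed by a call
// to reference_ops::Pad with the float tensor data and a pointer to that copy.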

  switch (op_context.input->type) {
    case kTfLiteFloat32: {
      float pad_value = op_context.constant_values == nullptr
                            ? 0.f
                            : *GetTensorData<float>(op_context.constant_values);
      if (op_context.resizing_category == ResizingCategory::kImageStyle) {
        TF_LITE_PAD(reference_ops, PadImageStyle, float, pad_value);
      } else {
        TF_LITE_PAD(reference_ops, Pad, float, pad_value);
      }
    } break;
    case kTfLiteUInt8: {
      uint8_t pad_value;
      if (op_context.constant_values == nullptr) {
        // Quantized Pad requires that 0 is represented in the quantized
        // range.
        TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
                                    std::numeric_limits<uint8_t>::min());
        TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
                                    std::numeric_limits<uint8_t>::max());
        pad_value = static_cast<uint8_t>(op_context.output->params.zero_point);
      } else {
        // Quantized Pad requires that 'constant_values' is represented in the
        // same quantized range as the input and output tensors.
        TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
                          op_context.constant_values->params.zero_point);
        TF_LITE_ENSURE_EQ(
            context, static_cast<double>(op_context.output->params.scale),
            static_cast<double>(op_context.constant_values->params.scale));
        pad_value = *GetTensorData<uint8_t>(op_context.constant_values);
      }
      if (op_context.resizing_category == ResizingCategory::kImageStyle) {
        TF_LITE_PAD(reference_ops, PadImageStyle, uint8_t, pad_value);
      } else {
        TF_LITE_PAD(reference_ops, Pad, uint8_t, pad_value);
      }
    } break;
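    // Quantization example (illustrative): with output zero_point 128 and no
    // constant_values input, the padded regions are filled with the quantized
    // value 128, which dequantizes back to real 0.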
    case kTfLiteInt8: {
      int8_t pad_value;
      if (op_context.constant_values == nullptr) {
        // Quantized Pad requires that 0 is represented in the quantized
        // range.
        TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
                                    std::numeric_limits<int8_t>::min());
        TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
                                    std::numeric_limits<int8_t>::max());
        pad_value = static_cast<int8_t>(op_context.output->params.zero_point);
      } else {
        // Quantized Pad requires that 'constant_values' is represented in the
        // same quantized range as the input and output tensors.
        TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
                          op_context.constant_values->params.zero_point);
        TF_LITE_ENSURE(context, op_context.output->params.scale ==
                                    op_context.constant_values->params.scale);
        pad_value = *GetTensorData<int8_t>(op_context.constant_values);
      }
      if (op_context.resizing_category == ResizingCategory::kImageStyle) {
        TF_LITE_PAD(reference_ops, PadImageStyle, int8_t, pad_value);
      } else {
        TF_LITE_PAD(reference_ops, Pad, int8_t, pad_value);
      }
    } break;
    case kTfLiteInt32: {
      int32_t pad_value =
          op_context.constant_values == nullptr
              ? 0
              : *GetTensorData<int32_t>(op_context.constant_values);
      TF_LITE_PAD(reference_ops, Pad, int32_t, pad_value);
    } break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported by Pad.",
                         TfLiteTypeGetName(op_context.input->type));
      return kTfLiteError;
  }
#undef TF_LITE_PAD
  return kTfLiteOk;
}

}  // namespace pad

TfLiteRegistration* Register_PAD() {
  static TfLiteRegistration r = {/*init=*/nullptr,
                                 /*free=*/nullptr,
                                 /*prepare=*/pad::Prepare,
                                 /*invoke=*/pad::Eval,
                                 /*profiling_string=*/nullptr,
                                 /*builtin_code=*/0,
                                 /*custom_name=*/nullptr,
                                 /*version=*/0};
  return &r;
}

// Also register Pad as PadV2.
TfLiteRegistration* Register_PADV2() {
  static TfLiteRegistration r = {/*init=*/nullptr,
                                 /*free=*/nullptr,
                                 /*prepare=*/pad::Prepare,
                                 /*invoke=*/pad::Eval,
                                 /*profiling_string=*/nullptr,
                                 /*builtin_code=*/0,
                                 /*custom_name=*/nullptr,
                                 /*version=*/0};
  return &r;
}
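
// Usage sketch (illustrative; assumes the MicroMutableOpResolver API, which
// is not part of this file):
//   static tflite::MicroMutableOpResolver resolver;
//   resolver.AddBuiltin(tflite::BuiltinOperator_PAD, Register_PAD());
//   resolver.AddBuiltin(tflite::BuiltinOperator_PADV2, Register_PADV2());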

}  // namespace micro
}  // namespace ops
}  // namespace tflite