tanh.cc 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
  13. #include "tensorflow/lite/c/builtin_op_data.h"
  14. #include "tensorflow/lite/c/common.h"
  15. #include "tensorflow/lite/kernels/internal/common.h"
  16. #include "tensorflow/lite/kernels/internal/quantization_util.h"
  17. #include "tensorflow/lite/kernels/internal/reference/tanh.h"
  18. #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
  19. #include "tensorflow/lite/kernels/kernel_util.h"
  20. #include "tensorflow/lite/kernels/op_macros.h"
  21. namespace tflite {
  22. namespace ops {
  23. namespace micro {
  24. namespace activations {
  25. namespace {
  // Index of the op's single input tensor within the node's input list.
  constexpr int kInputTensor{0};
  // Index of the op's single output tensor within the node's output list.
  constexpr int kOutputTensor{0};
  // Quantization parameters computed for the int8 Tanh path.
  //
  // Every member has a zero default so that a default-constructed OpData never
  // carries indeterminate values (for example on the float path, where
  // CalculateArithmeticOpData leaves the fields untouched).
  struct OpData {
    // Not written by any code visible in this file; kept for layout
    // compatibility.
    int32_t input_zero_point = 0;
    // Result of CalculateInputRadius() for the int8 path.
    int32_t input_range_radius = 0;
    // Fixed-point mantissa of the rescaled input scale (frexp mantissa scaled
    // by 2^31 and rounded).
    int32_t input_multiplier = 0;
    // Binary exponent that accompanies input_multiplier (frexp exponent).
    int input_left_shift = 0;
  };
  34. TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
  35. OpData* data) {
  36. const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  37. TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  38. TF_LITE_ENSURE_EQ(context, input->type, output->type);
  39. if (input->type == kTfLiteInt8) {
  40. TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
  41. // The number if input integer bits is set to be consistent with the
  42. // required value in reference_integer_ops::Tanh
  43. static constexpr int kInputIntegerBits = 4;
  44. const double input_real_multiplier =
  45. static_cast<double>(input->params.scale) *
  46. static_cast<double>(1 << (31 - kInputIntegerBits));
  47. const double q = std::frexp(input_real_multiplier, &data->input_left_shift);
  48. data->input_multiplier = static_cast<int32_t>(TfLiteRound(q * (1ll << 31)));
  49. data->input_range_radius =
  50. CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 31);
  51. }
  52. return kTfLiteOk;
  53. }
  54. } // namespace
  55. TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
  56. const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  57. TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  58. OpData data;
  59. CalculateArithmeticOpData(context, node, &data);
  60. if (input->type == kTfLiteFloat32) {
  61. switch (output->type) {
  62. case kTfLiteFloat32: {
  63. reference_ops::Tanh(GetTensorShape(input), GetTensorData<float>(input),
  64. GetTensorShape(output),
  65. GetTensorData<float>(output));
  66. return kTfLiteOk;
  67. }
  68. default:
  69. TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
  70. TfLiteTypeGetName(input->type),
  71. TfLiteTypeGetName(output->type));
  72. return kTfLiteError;
  73. }
  74. } else if (input->type == kTfLiteInt8) {
  75. switch (output->type) {
  76. case kTfLiteInt8: {
  77. reference_integer_ops::Tanh(
  78. input->params.zero_point, data.input_range_radius,
  79. data.input_multiplier, data.input_left_shift,
  80. NumElements(input->dims), GetTensorData<int8_t>(input),
  81. GetTensorData<int8_t>(output));
  82. return kTfLiteOk;
  83. }
  84. default:
  85. TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
  86. TfLiteTypeGetName(input->type),
  87. TfLiteTypeGetName(output->type));
  88. return kTfLiteError;
  89. }
  90. } else {
  91. TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
  92. TfLiteTypeGetName(input->type),
  93. TfLiteTypeGetName(output->type));
  94. return kTfLiteError;
  95. }
  96. return kTfLiteOk;
  97. }
  98. } // namespace activations
  99. TfLiteRegistration* Register_TANH() {
  100. static TfLiteRegistration r = {/*init=*/nullptr,
  101. /*free=*/nullptr,
  102. /*prepare=*/nullptr,
  103. /*invoke=*/activations::TanhEval,
  104. /*profiling_string=*/nullptr,
  105. /*builtin_code=*/0,
  106. /*custom_name=*/nullptr,
  107. /*version=*/0};
  108. return &r;
  109. }
  110. } // namespace micro
  111. } // namespace ops
  112. } // namespace tflite