/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ACTIVATIONS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ACTIVATIONS_H_

#include <algorithm>
#include <cstdint>
#include <limits>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace reference_ops {

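// Returns `value` * 2^amount, saturated to the int16_t range instead of
// wrapping on overflow. For example, SaturatingLeftShift(20000, 2) returns
// 32767: the true product 80000 does not fit in int16_t.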
inline int16_t SaturatingLeftShift(int16_t value, int amount) {
  int32_t result = static_cast<int32_t>(value) * (1 << amount);
  result = std::min<int32_t>(result, std::numeric_limits<int16_t>::max());
  result = std::max<int32_t>(result, std::numeric_limits<int16_t>::min());
  return result;
}

// Similar to ARM instruction SQDMULH.
// Similar to gemmlowp::SaturatingRoundingDoublingHighMul except
// rounding to zero instead of to nearest (SQRDMULH).
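// For example, a = b = 16384 (0.5 in Q0.15) gives ab_32 = 1 << 28 and a
// result of 8192 (0.25 in Q0.15). The only saturating case is
// a == b == -32768, whose true product (+1.0) is not representable.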
inline std::int16_t SaturatingDoublingHighMul(std::int16_t a, std::int16_t b) {
  bool overflow = a == b && a == std::numeric_limits<std::int16_t>::min();
  std::int32_t a_32(a);
  std::int32_t b_32(b);
  std::int32_t ab_32 = a_32 * b_32;
  std::int16_t ab_x2_high16 = static_cast<std::int16_t>((ab_32) / (1 << 15));
  return overflow ? std::numeric_limits<std::int16_t>::max() : ab_x2_high16;
}

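// Float reference implementation of hard swish:
//   output = input * relu6(input + 3) / 6
// where relu6(x) = min(6, max(0, x)), applied elementwise over the flat
// buffer.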
template <typename T>
inline void HardSwish(const RuntimeShape& input_shape, const T* input_data,
                      const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("ReferenceHardSwish/Float");
  auto matching_size = MatchingFlatSize(input_shape, output_shape);
  const T* in_end = input_data + matching_size;
  for (; input_data < in_end; input_data++, output_data++) {
    const float in = *input_data;
    *output_data =
        in * std::min(static_cast<T>(6), std::max(static_cast<T>(0), in + 3)) /
        6;
  }
}

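// Quantized hard swish. The step-by-step comments inside the loop explain
// the fixed-point arithmetic; roughly: subtract the input zero point,
// rescale x from [-3, 3] into a saturated Q15 value in [-1, 1], remap that
// to [0, 1] (this is the "relu-ish" factor, a clamped (x + 3) / 6), multiply
// it against x expressed on the pre-shift output scale, then apply the final
// output shift, zero point, and clamping.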
template <typename T>
inline void HardSwish(const HardSwishParams& params,
                      const RuntimeShape& input_shape, const T* input_data,
                      const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("ReferenceHardSwish/Quantized");

  const int flat_size = MatchingFlatSize(input_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    const int16_t input_value = input_data[i] - params.input_zero_point;
    // Left-shift as much as we can without overflow/saturation to put
    // significant bits in the high bits of our 16-bit fixedpoint values, so
    // that fixed-point approximate computations below are as accurate as
    // possible.
    const int16_t input_value_on_hires_input_scale = input_value * (1 << 7);
    // Compute the input value on essentially the output scale, just not
    // right-shifted yet. This is the value that we'll use in the (x >= +3)
    // case, and that in the general case we'll multiply against the "relu-ish"
    // fixed-point multiplier in [0, 1].
    const int16_t input_value_on_preshift_output_scale =
        gemmlowp::SaturatingRoundingDoublingHighMul(
            input_value_on_hires_input_scale,
            params.output_multiplier_fixedpoint_int16);
    // Now compute the "relu-ish multiplier". In the (-3 <= x <= +3) case, that
    // is just an affine rescaling of x from [-3, 3] to [0, 1]. In the general
    // case, it is just that plus saturation at the boundaries of [-3, 3].
    // First, we rescale from [-3, 3] to [-1, 1], saturating.
    // That is done by rescaling the input value with a fixed-point multiplier
    // (reluish_multiplier_fixedpoint) and bit-shift such that we represent
    // that input value on the scale where the real value 3.0f is represented
    // by the quantized value 32768. (+32768 is actually not representable as
    // int16_t, so this saturates at +32767, and that is seen empirically to be
    // a negligible contribution to numerical error/bias).
    //
    // This code is careful to correctly implement any magnitude of multiplier,
    // involving either a right shift or a left shift, with correct saturation
    // behavior in the left-shift case. This forces this code to be more
    // complicated, but is necessary for real applications: a partially
    // trained quantized MobileNet v3-small model that motivated this code
    // exhibits some large [min, max] range boundaries, of the order of
    // magnitude of 10 or 100 depending on layers.
    //
    // The next few lines are basically just an ordinary
    // MultiplyByQuantizedMultiplier, except that we are more careful here
    // about the fine details of saturation when left-shifting, because here
    // overflow in left-shift is a common case, not an anomaly as
    // MultiplyByQuantizedMultiplier assumes.
    int16_t reluish_value = input_value_on_hires_input_scale;
    // Shift left, saturating, as much as we can while ensuring that this
    // saturation will not contribute to the result. That is, left shift amount
    // reduced by 1.
    if (params.reluish_multiplier_exponent > 0) {
      reluish_value = SaturatingLeftShift(
          reluish_value, params.reluish_multiplier_exponent - 1);
    }
    // Apply the fixed-point multiplier, dividing the value by a divisor
    // ranging in [1, 2].
    reluish_value = gemmlowp::SaturatingRoundingDoublingHighMul(
        reluish_value, params.reluish_multiplier_fixedpoint_int16);
    // Apply the last bit of left-shift. Thus, in the left-shifting case, if
    // any saturation affects the result, it is happening here --- any
    // saturation having occurred above is overwritten here, not affecting the
    // result.
    if (params.reluish_multiplier_exponent > 0) {
      reluish_value = SaturatingLeftShift(reluish_value, 1);
    }
    // Shift right, in the right-shifting case.
    if (params.reluish_multiplier_exponent < 0) {
      reluish_value = gemmlowp::RoundingDivideByPOT(
          reluish_value, -params.reluish_multiplier_exponent);
    }
    // At this point we have rescaled the value into a 16bit fixedpoint
    // reluish_value in [-1, 1].
    // We now convert that to a 16bit fixedpoint value in [0, 1].
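    // For example, -32768 (-1.0) maps to 0 (0.0), and 0 (0.0) maps to
    // 16384 (0.5) on the new scale.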
    reluish_value = (reluish_value + (1 << 15)) >> 1;
    // Use of SaturatingDoublingHighMul here is important to cancel the biases
    // from the above SaturatingRoundingDoublingHighMul.
    //
    // On a partially trained MobileNet-v3-small,
    //
    //                                     | bias on    | ImageNet
    //                                     | quantized  | Top-1
    // Operation used here                 | values     | accuracy (50k)
    // ------------------------------------+------------+-----------
    // SaturatingDoublingHighMul           | -0.0024    | 58.920
    // SaturatingRoundingDoublingHighMul   | -0.0067    | 58.064
    //
    // In activations_test, this is covered by this testcase:
    // QuantizedActivationsOpTest.HardSwishBias
    //
    const int16_t preshift_output_value = SaturatingDoublingHighMul(
        reluish_value, input_value_on_preshift_output_scale);
    // We were so far operating on the pre-shift output scale. Now we finally
    // apply that output shift, arriving at the final output scale.
    int16_t output_value = gemmlowp::RoundingDivideByPOT(
        preshift_output_value, -params.output_multiplier_exponent);
    output_value += params.output_zero_point;
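    // Clamp to the representable range of the output type T (e.g.
    // [-128, 127] for int8_t).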
    output_value =
        std::min<int16_t>(output_value, std::numeric_limits<T>::max());
    output_value =
        std::max<int16_t>(output_value, std::numeric_limits<T>::min());
    output_data[i] = output_value;
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ACTIVATIONS_H_