runtime_op_utility.h 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. /* Copyright 2019-2020 Canaan Inc.
  2. *
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. *
  7. * http://www.apache.org/licenses/LICENSE-2.0
  8. *
  9. * Unless required by applicable law or agreed to in writing, software
  10. * distributed under the License is distributed on an "AS IS" BASIS,
  11. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. * See the License for the specific language governing permissions and
  13. * limitations under the License.
  14. */
  15. #pragma once
  16. #include "../datatypes.h"
  17. namespace nncase
  18. {
  19. namespace runtime
  20. {
  21. inline size_t get_bytes(datatype_t type)
  22. {
  23. size_t element_size;
  24. switch (type)
  25. {
  26. case dt_float32:
  27. element_size = 4;
  28. break;
  29. case dt_uint8:
  30. element_size = 1;
  31. break;
  32. default:
  33. NNCASE_THROW(std::runtime_error, "Not supported data type");
  34. }
  35. return element_size;
  36. }
  // Counts how many consecutive zero bits `value` has, scanning downwards
  // from bit (Bits - 1) until the first set bit is met.
  //
  // Bits:  number of low-order bits of `value` that take part in the scan.
  // value: integer to inspect; bits at positions >= Bits are ignored.
  // Returns the number of zero bits seen before the first set bit, or Bits
  // when none of the inspected bits is set (truncated to uint8_t).
  template <int32_t Bits, class T>
  uint8_t count_leading_zeros(T value)
  {
      // Walk down from the top bit; stop at the first bit that is set.
      int32_t bit = Bits - 1;
      while (bit >= 0 && (value & (1ULL << bit)) == 0)
          --bit;

      // Every position skipped above was a zero bit.
      return static_cast<uint8_t>(Bits - 1 - bit);
  }
  // Builds a mask whose lowest `shift` bits are set and all other bits clear.
  //
  // T:     integer type of the mask (uint64_t by default).
  // shift: number of low-order bits to set.
  // Returns (1 << shift) - 1 in type T. When `shift` is at least the bit
  // width of T the shift itself would be undefined behaviour, so the all-ones
  // mask is returned instead (e.g. bit_mask(64) == 0xFFFFFFFFFFFFFFFF).
  template <class T = uint64_t>
  inline T bit_mask(uint8_t shift)
  {
      // Width of T in bits; assumes 8-bit bytes.
      constexpr unsigned total_bits = static_cast<unsigned>(sizeof(T) * 8);
      if (shift >= total_bits)
          return static_cast<T>(~T(0));

      return (T(1) << shift) - 1;
  }
  // Drops the lowest `shift` bits of a fixed-point integer, rounding the
  // result to the nearest integer.
  //
  // T:      integer type of the value. Negative values rely on `>>` being an
  //         arithmetic (sign-propagating) shift.
  // Banker: when true, exact ties (fraction == 0.5) round to the nearest even
  //         integer; when false, half is added before shifting, so ties round
  //         towards positive infinity.
  // value:  fixed-point number with `shift` fractional bits.
  // shift:  number of fractional bits to remove. 0 returns `value` unchanged.
  //         Must be smaller than the number of value bits of T.
  // Returns the rounded integral part.
  //
  // `shift` is unsigned, so there is no negative (left-shift) case to handle.
  template <class T, bool Banker = false>
  T carry_shift(T value, uint8_t shift)
  {
      if (shift == 0)
          return value;

      if constexpr (Banker)
      {
          // Sign | Int (T - shift - 1 bits) | Frac (shift bits)
          //  S     IIII                       FFF
          // All intermediates are computed in T so that shifts of 32 bits or
          // more stay well-defined for 64-bit values.
          const T one = T(1);
          // Arithmetic shift floors the value, so `integral` is always the
          // lower neighbour and `fractional` is a non-negative remainder.
          const T integral = value >> shift;
          const T fractional = value & ((one << shift) - one);
          const T half = one << (shift - 1);

          // frac < 0.5: the floor is already the nearest integer.
          if (fractional < half)
              return integral;

          // frac > 0.5: the nearest integer is the upper neighbour. Because
          // `integral` is a floor this is +1 for negative inputs as well.
          if (fractional > half)
              return integral + 1;

          // frac == 0.5: pick whichever neighbour is even.
          return (integral & 1) ? integral + 1 : integral;
      }
      else
      {
          // Add 0.5 (in fixed-point) and truncate.
          value += T(1) << (shift - 1);
          value >>= shift;
          return value;
      }
  }
  103. template <bool Banker = false>
  104. inline int32_t mul_and_carry_shift(int32_t value, int32_t mul, uint8_t shift)
  105. {
  106. return (int32_t)carry_shift<int64_t, Banker>((int64_t)value * mul, shift);
  107. }
  // Restricts `value` to the range obtained by right-shifting the int32_t
  // extremes by (32 - Bits), e.g. [-128, 127] for Bits == 8.
  //
  // Bits:  width of the target range in bits.
  // value: number to restrict.
  // Returns `value` if it lies inside the range, otherwise the nearest bound.
  template <uint8_t Bits>
  inline int32_t clamp(int32_t value)
  {
      // Number of high-order bits to discard from the 32-bit extremes.
      const int drop = 32 - Bits;
      const int32_t lower = std::numeric_limits<int32_t>::lowest() >> drop;
      const int32_t upper = std::numeric_limits<int32_t>::max() >> drop;

      if (value < lower)
          return lower;
      if (upper < value)
          return upper;
      return value;
  }
  115. template <class T>
  116. struct to_datatype
  117. {
  118. };
  119. template <>
  120. struct to_datatype<float>
  121. {
  122. static constexpr datatype_t type = dt_float32;
  123. };
  124. template <>
  125. struct to_datatype<uint8_t>
  126. {
  127. static constexpr datatype_t type = dt_uint8;
  128. };
  129. template <class T>
  130. inline constexpr datatype_t to_datatype_v = to_datatype<T>::type;
  131. }
  132. }