fixedpoint.h 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912
  1. // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // fixedpoint.h: fixed-point arithmetic, with basic operations and
  15. // a few math functions such as tanh.
  16. #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
  17. #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>

#include "detect_platform.h"
  24. namespace gemmlowp {
  25. // Part 1: Low-level integer-arithmetic primitives.
  26. // The implementations here are generic implementations valid for
  27. // scalar types (e.g. std::int32_t). Architecture-specific SIMD types
  28. // (e.g. NEON int32x4_t) may be supported by providing
  29. // specializations for them in separate files.
  30. //
  31. // The purpose of these primitives is two-fold:
  32. // - They will be used to implement higher-level fixed-point
  33. // abstractions, namely the FixedPoint class and its arithmetic
  34. // operators.
  35. // - They will be directly used to implement some more involved
  36. // fixed-point computations, e.g. the fixed-point implementation
  37. // of math functions such as tanh.
  38. // Some compile-time traits around raw types to handle SIMD aspects:
  39. // number of lanes, underlying scalar type.
  40. template <typename tIntegerType>
  41. struct FixedPointRawTypeTraits {};
  42. template <>
  43. struct FixedPointRawTypeTraits<std::int32_t> {
  44. typedef std::int32_t ScalarRawType;
  45. static constexpr int kLanes = 1;
  46. };
  47. template <>
  48. struct FixedPointRawTypeTraits<std::int16_t> {
  49. typedef std::int16_t ScalarRawType;
  50. static constexpr int kLanes = 1;
  51. };
  52. // Returns a SIMD value duplicating a scalar value across all lanes.
  53. template <typename tRawType>
  54. tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
  55. return x;
  56. }
  57. // Plain bit-wise AND
  58. template <typename tIntegerType>
  59. tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  60. return a & b;
  61. }
  62. // Plain bit-wise OR
  63. template <typename tIntegerType>
  64. tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  65. return a | b;
  66. }
  67. // Plain bit-wise XOR
  68. template <typename tIntegerType>
  69. tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  70. return a ^ b;
  71. }
  72. // Plain bit-wise NOT
  73. template <typename tIntegerType>
  74. tIntegerType BitNot(tIntegerType a) {
  75. return ~a;
  76. }
  77. // Integer addition. Not saturating. Overflow is undefined behavior.
  78. template <typename tIntegerType>
  79. tIntegerType Add(tIntegerType a, tIntegerType b) {
  80. return a + b;
  81. }
  82. // Integer multiplication. Not saturating. Overflow is undefined behavior.
  83. template <typename tIntegerType>
  84. tIntegerType Mul(tIntegerType a, tIntegerType b) {
  85. return a * b;
  86. }
  87. // Integer subtraction. Not saturating. Overflow is undefined behavior.
  88. template <typename tIntegerType>
  89. tIntegerType Sub(tIntegerType a, tIntegerType b) {
  90. return a - b;
  91. }
  92. // Integer unary negative. Not saturating. Overflow is undefined behavior.
  93. template <typename tIntegerType>
  94. tIntegerType Neg(tIntegerType a) {
  95. return -a;
  96. }
  97. // Integer arithmetic left-shift, equivalent to multiplying with a power of two.
  98. // Negative values are OK. In case of overflow, no Undefined
  99. // Behavior, but the results are implementation-defined (in practice,
  100. // they currently are saturated, but we make no commitment to that). The idea
  101. // is that the caller will want to implement the overflowing cases with
  102. // saturation with compare-and-mask, so we don't care about the results
  103. // in the overflow case, we just want to avoid undefined behavior.
  104. //
  105. // tIntegerType may be int32 or any narrower signed type.
  106. template <typename tIntegerType, typename OffsetType>
  107. tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
  108. const std::int64_t wide_a = static_cast<std::int64_t>(a);
  109. const std::int64_t wide_shifted = wide_a * (1 << offset);
  110. const auto min = std::numeric_limits<tIntegerType>::min();
  111. const auto max = std::numeric_limits<tIntegerType>::max();
  112. return wide_shifted < min
  113. ? min
  114. : wide_shifted > max ? max
  115. : static_cast<tIntegerType>(wide_shifted);
  116. }
  117. // Integer arithmetic right-shift. Not rounding.
  118. // Relying on implementation-defined, but in-practice-consistent,
  119. // C++ compiler behavior.
  120. template <typename tIntegerType>
  121. tIntegerType ShiftRight(tIntegerType a, int offset) {
  122. return a >> offset;
  123. }
  124. // Each bit of the result is set to the corresponding bit of either then_val or
  125. // else_val depending on whether the corresponding bit of if_mask is set.
  126. // Equivalent to the VBSL instruction in ARM NEON.
  127. template <typename tIntegerType>
  128. tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
  129. tIntegerType else_val) {
  130. return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val));
  131. }
  132. // For each input scalar, the corresponding bits of the result are set if the
  133. // input scalar is non-zero.
  134. template <typename tIntegerType>
  135. tIntegerType MaskIfNonZero(tIntegerType a) {
  136. static constexpr tIntegerType zero = 0;
  137. return a ? BitNot(zero) : zero;
  138. }
  139. // For each input scalar, the corresponding bits of the result are set if the
  140. // input scalar is zero.
  141. template <typename tIntegerType>
  142. tIntegerType MaskIfZero(tIntegerType a) {
  143. return MaskIfNonZero<tIntegerType>(!a);
  144. }
  145. // For each pair of input scalars, the corresponding bits of the result are
  146. // set if the input scalars are equal.
  147. template <typename tIntegerType>
  148. tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
  149. return MaskIfNonZero<tIntegerType>(a == b);
  150. }
  151. // For each pair of input scalars, the corresponding bits of the result are
  152. // set if the input scalars are not equal.
  153. template <typename tIntegerType>
  154. tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
  155. return MaskIfNonZero<tIntegerType>(a != b);
  156. }
  157. // For each pair of input scalars, the corresponding bits of the result are
  158. // set if the input scalars a, b satisfy a > b.
  159. template <typename tIntegerType>
  160. tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
  161. return MaskIfNonZero<tIntegerType>(a > b);
  162. }
  163. // For each pair of input scalars, the corresponding bits of the result are
  164. // set if the input scalars a, b satisfy a >= b.
  165. template <typename tIntegerType>
  166. tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
  167. return MaskIfNonZero<tIntegerType>(a >= b);
  168. }
  169. // For each pair of input scalars, the corresponding bits of the result are
  170. // set if the input scalars a, b satisfy a < b.
  171. template <typename tIntegerType>
  172. tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
  173. return MaskIfNonZero<tIntegerType>(a < b);
  174. }
  175. // For each pair of input scalars, the corresponding bits of the result are
  176. // set if the input scalars a, b satisfy a <= b.
  177. template <typename tIntegerType>
  178. tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
  179. return MaskIfNonZero<tIntegerType>(a <= b);
  180. }
  181. // Returns true if all of the input scalars are nonzero.
  182. // This function may currently assume that each of the input scalars has either
  183. // all or none of its bits set. Otherwise, its behavior is currently undefined.
  184. template <typename tIntegerType>
  185. bool All(tIntegerType a) {
  186. return a;
  187. }
  188. // Returns true if any of the input scalars are nonzero.
  189. // This function may currently assume that each of the input scalars has either
  190. // all or none of its bits set. Otherwise, its behavior is currently undefined.
  191. template <typename tIntegerType>
  192. bool Any(tIntegerType a) {
  193. return a;
  194. }
  195. // Returns (a+b)/2, rounded to the nearest integer.
  196. // Equivalent to VRHADD in the ARM NEON instruction set.
  197. template <typename IntegerType>
  198. IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  199. static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  200. (void)b;
  201. return a;
  202. }
  203. template <>
  204. inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
  205. std::int64_t a64 = a;
  206. std::int64_t b64 = b;
  207. std::int64_t sum = a64 + b64;
  208. std::int64_t sign = sum >= 0 ? 1 : -1;
  209. return static_cast<std::int32_t>((sum + sign) / 2);
  210. }
  211. template <>
  212. inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
  213. std::int32_t a32 = a;
  214. std::int32_t b32 = b;
  215. std::int32_t sum = a32 + b32;
  216. std::int32_t sign = sum >= 0 ? 1 : -1;
  217. return static_cast<std::int16_t>((sum + sign) / 2);
  218. }
  219. template <typename IntegerType>
  220. IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
  221. static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  222. (void)b;
  223. return a;
  224. }
  225. // So far this is only needed for int16.
  226. template <>
  227. inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
  228. std::int32_t a32 = a;
  229. std::int32_t b32 = b;
  230. std::int32_t sum = a32 + b32;
  231. return static_cast<std::int16_t>(
  232. std::min(static_cast<std::int32_t>(32767),
  233. std::max(static_cast<std::int32_t>(-32768), sum)));
  234. }
  235. template <>
  236. inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) {
  237. std::int16_t a16 = a;
  238. std::int16_t b16 = b;
  239. std::int16_t sum = a16 + b16;
  240. return static_cast<std::int8_t>(std::min(
  241. static_cast<int16_t>(std::numeric_limits<int8_t>::max()),
  242. std::max(static_cast<int16_t>(std::numeric_limits<int8_t>::min()), sum)));
  243. }
  244. // Returns a+b, saturating if the integers are 16bit or narrower,
  245. // otherwise just a plain addition.
  246. template <typename IntegerType, bool Is16Bit>
  247. struct AddSaturatingIf16BitImpl {
  248. static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
  249. };
  250. template <typename IntegerType>
  251. struct AddSaturatingIf16BitImpl<IntegerType, true> {
  252. static IntegerType Run(IntegerType a, IntegerType b) {
  253. return SaturatingAdd(a, b);
  254. }
  255. };
  256. template <typename IntegerType>
  257. IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
  258. using ScalarType =
  259. typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
  260. return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
  261. b);
  262. }
  263. // Returns the integer that represents the product of two fixed-point
  264. // numbers, interpreting all integers as fixed-point values in the
  265. // interval [-1, 1), rounding to the nearest value, and saturating
  266. // -1 * -1 to the maximum value (since 1 is not in the half-open
  267. // interval [-1, 1)).
  268. //
  269. // [The explanation below specializes to std::int32_t for example purpose.]
  270. //
  271. // The mapping between IntegerType and the interval [-1, 1) is unique and
  272. // implied by IntegerType, which is assumed to be signed. For example,
  273. // for IntegerType==std::int32_t, the mapping is
  274. // real_value = integer_value / 2^31.
  275. // So in this case, and leaving aside rounding and saturating, this
  276. // function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to
  277. // (a * b) / 2^31.
  278. //
  279. // The 'doubling' part in the name of this function comes from the fact that
  280. // this operation is very close to a "multiply-high" operation, keeping only
  281. // the top half bits, except that that would be effectively computing
  282. // (a * b) / 2^32,
  283. // so here we are computing 2x that, since
  284. // 1/2^31 = 2 * 1/2^32.
  285. // The idea is to use all of the available 32 bits in the destination int32
  286. // value.
  287. //
  288. // [End of the explanation specializing to int32.]
  289. //
  290. // This is equivalent to the VQRDMULH instruction in ARM NEON.
  291. template <typename IntegerType>
  292. IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
  293. static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  294. (void)b;
  295. return a;
  296. }
  297. // This function implements the same computation as the ARMv7 NEON VQRDMULH
  298. // instruction.
  299. template <>
  300. inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
  301. std::int32_t b) {
  302. bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
  303. std::int64_t a_64(a);
  304. std::int64_t b_64(b);
  305. std::int64_t ab_64 = a_64 * b_64;
  306. std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
  307. std::int32_t ab_x2_high32 =
  308. static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
  309. return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
  310. }
  311. template <>
  312. inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
  313. std::int16_t b) {
  314. bool overflow = a == b && a == std::numeric_limits<std::int16_t>::min();
  315. std::int32_t a_32(a);
  316. std::int32_t b_32(b);
  317. std::int32_t ab_32 = a_32 * b_32;
  318. std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14));
  319. std::int16_t ab_x2_high16 =
  320. static_cast<std::int16_t>((ab_32 + nudge) / (1 << 15));
  321. return overflow ? std::numeric_limits<std::int16_t>::max() : ab_x2_high16;
  322. }
  323. // Correctly-rounded-to-nearest division by a power-of-two.
  324. // Also known as a rounding arithmetic right shift.
  325. template <typename IntegerType, typename ExponentType>
  326. inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
  327. assert(exponent >= 0);
  328. assert(exponent <= 31);
  329. const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
  330. const IntegerType zero = Dup<IntegerType>(0);
  331. const IntegerType one = Dup<IntegerType>(1);
  332. const IntegerType remainder = BitAnd(x, mask);
  333. const IntegerType threshold =
  334. Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
  335. return Add(ShiftRight(x, exponent),
  336. BitAnd(MaskIfGreaterThan(remainder, threshold), one));
  337. }
  338. // Returns the product of a run-time integer value by a compile-time power
  339. // of two, with either a positive exponent (equivalent to an arithmetic
  340. // left shift, saturating) or a negative exponent (equivalent to an arithmetic
  341. // right shift, rounding to nearest).
  342. template <int Exponent, typename IntegerType,
  343. int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
  344. struct ImplSaturatingRoundingMultiplyByPOT {};
  345. template <int Exponent, typename IntegerType>
  346. struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
  347. static IntegerType eval(IntegerType x) { return x; }
  348. };
  349. template <int Exponent, typename IntegerType>
  350. struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
  351. static IntegerType eval(IntegerType x) {
  352. using ScalarIntegerType =
  353. typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
  354. const IntegerType min =
  355. Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
  356. const IntegerType max =
  357. Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
  358. const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
  359. const std::int32_t threshold =
  360. ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
  361. const IntegerType positive_mask =
  362. MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
  363. const IntegerType negative_mask =
  364. MaskIfLessThan(x, Dup<IntegerType>(-threshold));
  365. IntegerType result = ShiftLeft(x, Exponent);
  366. result = SelectUsingMask(positive_mask, max, result);
  367. result = SelectUsingMask(negative_mask, min, result);
  368. return result;
  369. }
  370. };
  371. template <int Exponent, typename IntegerType>
  372. struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
  373. static IntegerType eval(IntegerType x) {
  374. return RoundingDivideByPOT<IntegerType>(x, -Exponent);
  375. }
  376. };
  377. template <int Exponent, typename IntegerType>
  378. IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
  379. return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
  380. }
  381. // Part 2: the FixedPoint class.
  382. // A FixedPoint object represents a fixed-point value stored in the underlying
  383. // integer type tRawType, if tRawType is a plain scalar integer type.
  384. // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
  385. // case a FixedPoint object represents a corresponding SIMD vector of fixed
  386. // point values.
  387. //
  388. // tIntegerBits describes the range of the fixed-point format: if
  389. // tIntegerBits == m then the range of representable values is the half-open
  390. // interval [-2^m; 2^m) where the open boundary on the right side means that
  391. // 2^m is not representable (how close the maximum representable value is to
  392. // it, depends on bit-depth of tRawType).
  393. //
  394. // In "Q format notation",
  395. // https://en.wikipedia.org/wiki/Q_(number_format)
  396. // we are describing the format
  397. // Qm.n
  398. // where
  399. // m = tIntegerBits
  400. // and
  401. // n = NumberOfBits(tRawType) - (m + 1)
  402. // Note that the (m + 1) in the above line is because we adopt the convention
  403. // that we count the integer bits exclusively of the sign bit; so (m + 1) is
  404. // the total number of integer bits inclusive of the sign bit.
  405. //
  406. // Accordingly, the number of integral representable values in our range
  407. // [-2^m ; 2^m)
  408. // is equal to 2^(m+1).
  409. template <typename tRawType, int tIntegerBits>
  410. class FixedPoint {
  411. public:
  412. typedef tRawType RawType;
  413. typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
  414. typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
  415. static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
  416. static constexpr int kIntegerBits = tIntegerBits;
  417. static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
  418. static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
  419. "bad IntegerBits");
  420. typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
  421. static const ScalarRawType ScalarRawMin() {
  422. return std::numeric_limits<ScalarRawType>::min();
  423. }
  424. static const ScalarRawType ScalarRawMax() {
  425. return std::numeric_limits<ScalarRawType>::max();
  426. }
  427. static const ScalarRawType RawMin() {
  428. return VectorFromScalar(ScalarRawMin());
  429. }
  430. static const ScalarRawType RawMax() {
  431. return VectorFromScalar(ScalarRawMax());
  432. }
  433. static FixedPoint FromRaw(RawType x) {
  434. FixedPoint retval;
  435. retval.raw() = x;
  436. return retval;
  437. }
  438. static FixedPoint FromScalarRaw(ScalarRawType x) {
  439. FixedPoint retval;
  440. retval.raw() = Dup<RawType>(x);
  441. return retval;
  442. }
  443. static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
  444. return FromScalarRaw(x.raw());
  445. }
  446. template <int Exponent>
  447. static FixedPoint ConstantPOT() {
  448. static constexpr int kOffset = kFractionalBits + Exponent;
  449. static_assert(
  450. kOffset < 31,
  451. "Constant not exactly representable in this fixed-point format");
  452. return FromScalarRaw(ScalarRawType(1) << kOffset);
  453. }
  454. static FixedPoint Zero() { return FromScalarRaw(0); }
  455. static FixedPoint One() {
  456. return FromScalarRaw(
  457. kIntegerBits == 0
  458. ? ScalarRawMax()
  459. : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
  460. }
  461. static FixedPoint FromDouble(double x) {
  462. const double min_bound = static_cast<double>(ScalarRawMin());
  463. const double max_bound = static_cast<double>(ScalarRawMax());
  464. return FromScalarRaw(static_cast<ScalarRawType>(std::min(
  465. std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
  466. min_bound),
  467. max_bound)));
  468. }
  469. RawType raw() const { return i_; }
  470. RawType& raw() { return i_; }
  471. private:
  472. RawType i_;
  473. };
  474. // Part 3: implementation of arithmetic operators for the
  475. // FixedPoint class, and a few related functions.
  476. // A FixedPoint multiplication is just a
  477. // SaturatingRoundingDoublingHighMul operation on the underlying
  478. // raw integer values. The IntegerBits simply add up, as is obvious
  479. // from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
  480. template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
  481. FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
  482. FixedPoint<tRawType, tIntegerBits_a> a,
  483. FixedPoint<tRawType, tIntegerBits_b> b) {
  484. FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
  485. c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
  486. return c;
  487. }
  488. // Tweaking IntegerBits gives exact multiplication by a power of two.
  489. template <int tExponent, typename tRawType, int tIntegerBits>
  490. FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
  491. FixedPoint<tRawType, tIntegerBits> a) {
  492. FixedPoint<tRawType, tExponent + tIntegerBits> c;
  493. c.raw() = a.raw();
  494. return c;
  495. }
  496. // If we want to leave IntegerBits fixed, then multiplication
  497. // by a power of two has to be saturating/rounding, not exact anymore.
  498. template <int tExponent, typename tRawType, int tIntegerBits>
  499. FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
  500. FixedPoint<tRawType, tIntegerBits> a) {
  501. return FixedPoint<tRawType, tIntegerBits>::FromRaw(
  502. SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
  503. }
  504. // Generic arithmetic operators.
  505. #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \
  506. template <typename tRawType, int tIntegerBits> \
  507. FixedPoint<tRawType, tIntegerBits> FuncName( \
  508. FixedPoint<tRawType, tIntegerBits> a) { \
  509. return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
  510. }
  511. #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
  512. template <typename tRawType, int tIntegerBits> \
  513. FixedPoint<tRawType, tIntegerBits> FuncName( \
  514. FixedPoint<tRawType, tIntegerBits> a, \
  515. FixedPoint<tRawType, tIntegerBits> b) { \
  516. return FixedPoint<tRawType, tIntegerBits>::FromRaw( \
  517. ImplFuncName(a.raw(), b.raw())); \
  518. }
  519. MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
  520. MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
  521. MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
  522. MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
  523. MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
  524. MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
  525. MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
  526. MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
  527. #undef MAKE_FIXEDPOINT_UNARY_FUNC
  528. #undef MAKE_FIXEDPOINT_BINARY_FUNC
  529. #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \
  530. template <typename tRawType, int tIntegerBits> \
  531. tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
  532. return FuncName(a.raw()); \
  533. }
  534. #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
  535. template <typename tRawType, int tIntegerBits> \
  536. tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a, \
  537. FixedPoint<tRawType, tIntegerBits> b) { \
  538. return FuncName(a.raw(), b.raw()); \
  539. }
  540. MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
  541. MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
  542. MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
  543. MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
  544. MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
  545. MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
  546. MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
  547. MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
  548. #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
  549. #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
  550. template <typename tRawType, int tIntegerBits>
  551. FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
  552. tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
  553. FixedPoint<tRawType, tIntegerBits> else_val) {
  554. return FixedPoint<tRawType, tIntegerBits>::FromRaw(
  555. SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
  556. }
  557. template <typename tRawType, int tIntegerBits>
  558. bool operator==(FixedPoint<tRawType, tIntegerBits> a,
  559. FixedPoint<tRawType, tIntegerBits> b) {
  560. return All(MaskIfEqual(a.raw(), b.raw()));
  561. }
  562. template <typename tRawType, int tIntegerBits>
  563. bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
  564. FixedPoint<tRawType, tIntegerBits> b) {
  565. return !(a == b);
  566. }
  567. template <typename tRawType, int tIntegerBits>
  568. FixedPoint<tRawType, tIntegerBits> SaturatingAdd(
  569. FixedPoint<tRawType, tIntegerBits> a,
  570. FixedPoint<tRawType, tIntegerBits> b) {
  571. return FixedPoint<tRawType, tIntegerBits>::FromRaw(
  572. SaturatingAdd(a.raw(), b.raw()));
  573. }
  574. template <typename tRawType, int tIntegerBits>
  575. FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(
  576. FixedPoint<tRawType, tIntegerBits> a,
  577. FixedPoint<tRawType, tIntegerBits> b) {
  578. return FixedPoint<tRawType, tIntegerBits>::FromRaw(
  579. AddSaturatingIf16Bit(a.raw(), b.raw()));
  580. }
  581. // Conversion to floating-point.
  582. template <typename tRawType, int tIntegerBits>
  583. double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
  584. static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
  585. "not applicable to SIMD types");
  586. typedef FixedPoint<tRawType, tIntegerBits> F;
  587. return x.raw() / static_cast<double>(1ll << F::kFractionalBits);
  588. }
  589. // Rescale changes the number of IntegerBits and updates the underlying
  590. // raw integer value accordingly.
  591. template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
  592. FixedPoint<tRawType, tIntegerBitsDst> Rescale(
  593. FixedPoint<tRawType, tIntegerBitsSrc> x) {
  594. static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
  595. FixedPoint<tRawType, tIntegerBitsDst> result;
  596. result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
  597. return result;
  598. }
  599. // CheckedFixedPointConstant allows to specify fixed-point constants
  600. // initialized as real numbers, in a way that does not compile floating-point
  601. // arithmetic in production code, yet still checks agreement with the
  602. // floating-point expressions when asserts are enabled.
  603. //
  604. // The raw integer value provided is always a int32, encoding a 32-bit
  605. // fixed-point value, regardless of the actual Scalar type. This allows
  606. // writing generic code that applies just as well to the 32-bit and 16-bit
  607. // cases. In the 16-bit case, the raw integer value is internally
  608. // rounding-shifted by 16 bits to the right.
  609. template <typename FixedPointType>
  610. inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(
  611. std::int32_t int32_value) {
  612. typedef typename FixedPointType::ScalarRawType ScalarRawType;
  613. static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
  614. return static_cast<ScalarRawType>(
  615. RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));
  616. }
  617. #ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
  618. template <typename FixedPointType>
  619. FixedPointType CheckedFixedPointConstant(std::int32_t raw_value,
  620. double double_value) {
  621. const FixedPointType result = FixedPointType::FromScalarRaw(raw_value);
  622. assert(result == FixedPointType::FromDouble(double_value));
  623. return result;
  624. }
  625. #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \
  626. ScalarRawInt32Value, DoubleValue) \
  627. (gemmlowp::CheckedFixedPointConstant<FixedPointType>( \
  628. gemmlowp::RescaleConstantInitializer<FixedPointType>( \
  629. ScalarRawInt32Value), \
  630. DoubleValue))
  631. #else
  632. #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \
  633. ScalarRawInt32Value, DoubleValue) \
  634. (FixedPointType::FromScalarRaw( \
  635. gemmlowp::RescaleConstantInitializer<FixedPointType>( \
  636. ScalarRawInt32Value)))
  637. #endif
  638. // Implementation of exponential function.
  639. // Returns exp(x) for x in [-1/4, 0).
  640. template <typename tRawType>
  641. FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
  642. FixedPoint<tRawType, 0> a) {
  643. typedef FixedPoint<tRawType, 0> F;
  644. const F constant_term =
  645. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
  646. const F constant_1_over_3 =
  647. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
  648. // We're evaluating a Taylor expansion around -1/8, so we do the change of
  649. // variable: x = a + 1/8.
  650. // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
  651. F x = a + F::template ConstantPOT<-3>();
  652. F x2 = x * x;
  653. F x3 = x2 * x;
  654. F x4 = x2 * x2;
  655. F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
  656. F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
  657. SaturatingRoundingMultiplyByPOT<-1>(
  658. ((x4_over_4 + x3) * constant_1_over_3) + x2);
  659. return AddSaturatingIf16Bit(
  660. constant_term,
  661. constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
  662. }
// Returns exp(x) for x < 0. An input of exactly 0 is also accepted and yields
// ResultF::One() (forced by the final SelectUsingMask below).
//
// Range reduction: with
//   a_mod_quarter_minus_one_quarter = (a mod 1/4) - 1/4   (in [-1/4, 0))
//   remainder = a_mod_quarter_minus_one_quarter - a        (>= 0 for a < 0,
//                                                           a multiple of 1/4)
// we have exp(a) = exp(a_mod_quarter_minus_one_quarter) * exp(-remainder).
// The first factor comes from the polynomial on [-1/4, 0); the second is
// applied one set bit of remainder at a time by the barrel shifter below.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> exp_on_negative_values(
    FixedPoint<tRawType, tIntegerBits> a) {
  typedef FixedPoint<tRawType, tIntegerBits> InputF;
  typedef FixedPoint<tRawType, 0> ResultF;
  static constexpr int kFractionalBits = InputF::kFractionalBits;
  static constexpr int kIntegerBits = InputF::kIntegerBits;
  const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
  // Raw pattern with every bit below the 1/4 position set.
  InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
  // (a & mask) keeps a mod 1/4 in [0, 1/4); subtracting 1/4 shifts it into
  // [-1/4, 0), the domain of the polynomial approximation.
  InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
  ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
      Rescale<0>(a_mod_quarter_minus_one_quarter));
  // Part of -a not yet accounted for; its bit worth 2^Exponent selects a
  // multiplication by exp(-2^Exponent) below.
  tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();
  // For each Exponent: if the remainder bit worth 2^Exponent (raw bit
  // kFractionalBits + Exponent) is set, multiply result by exp(-2^Exponent).
  // The `kIntegerBits > Exponent` guard skips bits the input format cannot
  // hold. kShiftAmount repeats that condition and falls back to 0, presumably
  // so `1 << kShiftAmount` stays a valid shift in the branch that is compiled
  // but not taken (the `if` is not a constexpr-if).
#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
  if (kIntegerBits > Exponent) {                                            \
    const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
        ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
    static constexpr int kShiftAmount =                                     \
        kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;           \
    result = SelectUsingMask(                                               \
        MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
        result * kMultiplier, result);                                      \
  }
  // Constants below are the int32 raw encodings, with 0 integer bits, of
  // exp(-2^Exponent):
  GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); // exp(-1/4)
  GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); // exp(-1/2)
  GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); // exp(-1)
  GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); // exp(-2)
  GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); // exp(-4)
  GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); // exp(-8)
  GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); // exp(-16)
#undef GEMMLOWP_EXP_BARREL_SHIFTER
  // The shifter covers remainder bits only up to 16, i.e. inputs down to -32.
  // When the input format can represent values below -32 (kIntegerBits > 5),
  // such inputs are clamped to a result of 0. The raw value -(1 << clampB)
  // encodes -32.0 in InputF, as checked by the macro.
  static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
  if (kIntegerBits > 5) {
    const InputF clamp =
        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
    result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
  }
  // exp(0) == 1; the range reduction above does not produce it for a == 0.
  result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
  return result;
}
  705. // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).
  706. // Returns (1 - x) / (1 + x) for x in (0, 1).
  707. template <typename tRawType>
  708. FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
  709. FixedPoint<tRawType, 0> a) {
  710. typedef FixedPoint<tRawType, 0> F0;
  711. typedef FixedPoint<tRawType, 2> F2;
  712. F0 half_denominator = RoundingHalfSum(a, F0::One());
  713. // Newton-Raphson division
  714. // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
  715. // Refer to that page for the logic behind the 48/17 and 32/17 constants.
  716. const F2 constant_48_over_17 =
  717. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
  718. const F2 constant_neg_32_over_17 =
  719. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
  720. F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
  721. for (int i = 0; i < 3; i++) {
  722. F2 half_denominator_times_x = half_denominator * x;
  723. F2 one_minus_half_denominator_times_x =
  724. F2::One() - half_denominator_times_x;
  725. x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
  726. }
  727. return Rescale<0>(x - F2::One());
  728. }
  729. // Returns -tanh(x) for x < 0.
  730. template <typename tRawType, int tIntegerBits>
  731. FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
  732. FixedPoint<tRawType, tIntegerBits> a) {
  733. return one_minus_x_over_one_plus_x_for_x_in_0_1(
  734. exp_on_negative_values(ExactMulByPot<1>(a)));
  735. }
  736. // Returns tanh(x) for any x.
  737. template <typename tRawType, int tIntegerBits>
  738. FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
  739. typedef FixedPoint<tRawType, tIntegerBits> InputF;
  740. typedef FixedPoint<tRawType, 0> ResultF;
  741. tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
  742. tRawType mask_if_zero = MaskIfZero(a);
  743. InputF n = SelectUsingMask(mask_if_negative, a, -a);
  744. ResultF t = neg_tanh_on_negative_values(n);
  745. return SelectUsingMask(mask_if_zero, ResultF::Zero(),
  746. SelectUsingMask(mask_if_negative, -t, t));
  747. }
  748. // Implementation of logistic function.
  749. // Returns 1 / (1 + x) for x in (0, 1).
  750. template <typename tRawType>
  751. FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(
  752. FixedPoint<tRawType, 0> a) {
  753. typedef FixedPoint<tRawType, 0> F0;
  754. typedef FixedPoint<tRawType, 2> F2;
  755. F0 half_denominator = RoundingHalfSum(a, F0::One());
  756. // Newton-Raphson division
  757. // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
  758. // Refer to that page for the logic behind the 48/17 and 32/17 constants.
  759. const F2 constant_48_over_17 =
  760. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
  761. const F2 constant_neg_32_over_17 =
  762. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
  763. F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
  764. for (int i = 0; i < 3; i++) {
  765. F2 half_denominator_times_x = half_denominator * x;
  766. F2 one_minus_half_denominator_times_x =
  767. F2::One() - half_denominator_times_x;
  768. x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
  769. }
  770. return Rescale<0>(ExactMulByPot<-1>(x));
  771. }
  772. // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
  773. template <typename tRawType, int tIntegerBits>
  774. FixedPoint<tRawType, 0> logistic_on_positive_values(
  775. FixedPoint<tRawType, tIntegerBits> a) {
  776. return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
  777. }
  778. // Returns logistic(x) = 1 / (1 + exp(-x)) for any x.
  779. template <typename tRawType, int tIntegerBits>
  780. FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
  781. typedef FixedPoint<tRawType, tIntegerBits> InputF;
  782. typedef FixedPoint<tRawType, 0> ResultF;
  783. tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero());
  784. tRawType mask_if_zero = MaskIfZero(a);
  785. InputF abs_input = SelectUsingMask(mask_if_positive, a, -a);
  786. ResultF result_if_positive = logistic_on_positive_values(abs_input);
  787. ResultF result_if_negative = ResultF::One() - result_if_positive;
  788. const ResultF one_half =
  789. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
  790. return SelectUsingMask(mask_if_zero, one_half,
  791. SelectUsingMask(mask_if_positive, result_if_positive,
  792. result_if_negative));
  793. }
  794. } // end namespace gemmlowp
  795. #ifdef GEMMLOWP_NEON
  796. #include "./fixedpoint_neon.h"
  797. #elif defined(GEMMLOWP_AVX2)
  798. #include "./fixedpoint_avx.h"
  799. #elif defined(GEMMLOWP_SSE4)
  800. #include "./fixedpoint_sse.h"
  801. #elif defined(GEMMLOWP_MSA)
  802. #include "./fixedpoint_msa.h"
  803. #endif
  804. #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_