| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912 |
- // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- // fixedpoint.h: fixed-point arithmetic, with basic operations and
- // a few math functions such as tanh.
- #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
- #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>

#include "detect_platform.h"
- namespace gemmlowp {
- // Part 1: Low-level integer-arithmetic primitives.
- // The implementations here are generic implementations valid for
- // scalar types (e.g. std::int32_t). Architecture-specific SIMD types
- // (e.g. NEON int32x4_t) may be supported by providing
- // specializations for them in separate files.
- //
- // The purpose of these primitives is two-fold:
- // - They will be used to implement higher-level fixed-point
- // abstractions, namely the FixedPoint class and its arithmetic
- // operators.
- // - They will be directly used to implement some more involved
- // fixed-point computations, e.g. the fixed-point implementation
- // of math functions such as tanh.
// Compile-time traits describing a raw storage type:
//   ScalarRawType - the scalar integer type stored in each lane;
//   kLanes        - how many lanes one raw value holds (1 for plain scalars).
// The primary template is deliberately empty, so only raw types that have a
// specialization (here, or in an architecture-specific file) can be used.
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  using ScalarRawType = std::int32_t;
  static constexpr int kLanes = 1;
};

template <>
struct FixedPointRawTypeTraits<std::int16_t> {
  using ScalarRawType = std::int16_t;
  static constexpr int kLanes = 1;
};

// Builds a tRawType value whose every lane equals `x`. For the single-lane
// scalar raw types above this is simply `x` converted to tRawType.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
  const tRawType all_lanes = x;
  return all_lanes;
}
// Bit-wise primitives on scalar integers. The generic versions below apply the
// built-in operators; SIMD raw types are supported by specializations provided
// elsewhere.

// Bit-wise AND of a and b.
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  const tIntegerType both = a & b;
  return both;
}

// Bit-wise OR of a and b.
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  const tIntegerType either = a | b;
  return either;
}

// Bit-wise XOR of a and b.
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  const tIntegerType differing = a ^ b;
  return differing;
}

// Bit-wise complement of a.
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  const tIntegerType flipped = ~a;
  return flipped;
}

// Plain (non-saturating) integer arithmetic. The caller must keep results in
// range: overflow of int32 operands is undefined behavior.

// a + b.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  const tIntegerType sum = a + b;
  return sum;
}

// a * b.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  const tIntegerType product = a * b;
  return product;
}

// a - b.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  const tIntegerType difference = a - b;
  return difference;
}

// -a (negating the most negative value overflows).
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  const tIntegerType negated = -a;
  return negated;
}
// Integer arithmetic left-shift, equivalent to multiplying by 2^offset.
// Negative values are OK. In case of overflow there is no undefined behavior,
// but the results are implementation-defined (in practice they currently
// saturate, but we make no commitment to that). The idea is that the caller
// will implement the overflowing cases with saturation via compare-and-mask,
// so we don't care about the results in the overflow case; we just want to
// avoid undefined behavior.
//
// tIntegerType may be int32 or any narrower signed type; offset must be in
// [0, 31]. The power of two is formed as a 64-bit value so that offset == 31
// does not overflow `int` (1 << 31 is undefined behavior before C++20), and
// |a| * 2^31 <= 2^62 always fits in the 64-bit intermediate.
template <typename tIntegerType, typename OffsetType>
tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
  const std::int64_t wide_a = static_cast<std::int64_t>(a);
  const std::int64_t wide_shifted = wide_a * (std::int64_t{1} << offset);
  const auto min = std::numeric_limits<tIntegerType>::min();
  const auto max = std::numeric_limits<tIntegerType>::max();
  if (wide_shifted < min) {
    return min;
  }
  if (wide_shifted > max) {
    return max;
  }
  return static_cast<tIntegerType>(wide_shifted);
}
// Arithmetic (sign-propagating) right shift of `a` by `offset` bits. It does
// not round to nearest: the result rounds toward negative infinity. Relies on
// `>>` of a negative value being an arithmetic shift, which is
// implementation-defined before C++20 but consistent across compilers.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  const tIntegerType shifted = a >> offset;
  return shifted;
}
// Bit-wise select, like the ARM NEON VBSL instruction: each result bit is
// taken from then_val where the corresponding bit of if_mask is set, and from
// else_val where it is clear. The two masked terms never share a set bit, so
// combining them with XOR is equivalent to OR.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  const tIntegerType from_then = BitAnd(if_mask, then_val);
  const tIntegerType from_else = BitAnd(BitNot(if_mask), else_val);
  return BitXor(from_then, from_else);
}
// Mask primitives. In each lane the result has all bits set when the
// condition holds and all bits clear otherwise, so it can be passed directly
// as the if_mask of SelectUsingMask or to All/Any.

// Mask of lanes where `a` is non-zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  const tIntegerType no_bits = 0;
  if (!a) {
    return no_bits;
  }
  return BitNot(no_bits);
}

// Mask of lanes where `a` is zero.
template <typename tIntegerType>
tIntegerType MaskIfZero(tIntegerType a) {
  const bool is_zero = !a;
  return MaskIfNonZero<tIntegerType>(is_zero);
}

// Mask of lanes where a == b.
template <typename tIntegerType>
tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
  const bool holds = (a == b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// Mask of lanes where a != b.
template <typename tIntegerType>
tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
  const bool holds = (a != b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// Mask of lanes where a > b.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
  const bool holds = (a > b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// Mask of lanes where a >= b.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
  const bool holds = (a >= b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// Mask of lanes where a < b.
template <typename tIntegerType>
tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
  const bool holds = (a < b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// Mask of lanes where a <= b.
template <typename tIntegerType>
tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
  const bool holds = (a <= b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// True when every lane of `a` is non-zero. This may assume that each lane is
// a full mask (all bits set or none); otherwise its behavior is currently
// undefined.
template <typename tIntegerType>
bool All(tIntegerType a) {
  return static_cast<bool>(a);
}

// True when at least one lane of `a` is non-zero. Same all-or-none lane
// assumption as All; otherwise its behavior is currently undefined.
template <typename tIntegerType>
bool Any(tIntegerType a) {
  return static_cast<bool>(a);
}
// Returns (a + b) / 2 rounded to the nearest integer, with exact halves
// rounded away from zero (e.g. (1, 2) -> 2 and (-1, -2) -> -2). The sum is
// formed in a wider type, so it cannot overflow.
// NOTE(review): ARM NEON VRHADD computes (a + b + 1) >> 1, which rounds halves
// toward +infinity; the two agree only when a + b is non-negative or even —
// confirm which behavior is intended for the scalar path.
//
// Only the int32 and int16 specializations below are implemented; other types
// without a specialization trip the static_assert.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  (void)b;
  return a;
}

template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
  const std::int64_t wide_sum =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  // A +/-1 bias before the truncating division turns .5 into a rounding away
  // from zero.
  const std::int64_t bias = (wide_sum < 0) ? -1 : 1;
  return static_cast<std::int32_t>((wide_sum + bias) / 2);
}

template <>
inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
  const std::int32_t wide_sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  const std::int32_t bias = (wide_sum < 0) ? -1 : 1;
  return static_cast<std::int16_t>((wide_sum + bias) / 2);
}
// Returns a + b clamped to the representable range of IntegerType.
// Specializations exist for int16 and int8; other types without a
// specialization trip the static_assert.
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  (void)b;
  return a;
}

template <>
inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
  const std::int32_t kHi = std::numeric_limits<std::int16_t>::max();
  const std::int32_t kLo = std::numeric_limits<std::int16_t>::min();
  // The sum of two int16 values always fits in int32.
  const std::int32_t wide_sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  const std::int32_t clamped =
      wide_sum > kHi ? kHi : (wide_sum < kLo ? kLo : wide_sum);
  return static_cast<std::int16_t>(clamped);
}

template <>
inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) {
  const std::int16_t kHi = std::numeric_limits<std::int8_t>::max();
  const std::int16_t kLo = std::numeric_limits<std::int8_t>::min();
  // The sum of two int8 values always fits in int16.
  const std::int16_t wide_sum = static_cast<std::int16_t>(a + b);
  const std::int16_t clamped =
      wide_sum > kHi ? kHi : (wide_sum < kLo ? kLo : wide_sum);
  return static_cast<std::int8_t>(clamped);
}
- // Returns a+b, saturating if the integers are 16bit or narrower,
- // otherwise just a plain addition.
- template <typename IntegerType, bool Is16Bit>
- struct AddSaturatingIf16BitImpl {
- static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
- };
- template <typename IntegerType>
- struct AddSaturatingIf16BitImpl<IntegerType, true> {
- static IntegerType Run(IntegerType a, IntegerType b) {
- return SaturatingAdd(a, b);
- }
- };
- template <typename IntegerType>
- IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
- using ScalarType =
- typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
- return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
- b);
- }
// Multiplies two fixed-point numbers in the all-fractional format, where
// every integer is read as a value in [-1, 1). The product is rounded to
// nearest, and the single overflowing case -1 * -1 saturates to the maximum
// value, because +1 lies outside the half-open interval [-1, 1).
//
// Worked example for std::int32_t: the mapping is
//   real_value = integer_value / 2^31,
// so, leaving rounding and saturation aside, the result is
//   ((a / 2^31) * (b / 2^31)) * 2^31 = (a * b) / 2^31.
// A plain "multiply-high" keeps the top 32 bits of the 64-bit product, which
// is (a * b) / 2^32; this function returns twice that (hence "doubling"),
// since 1 / 2^31 = 2 * (1 / 2^32). That way all 32 bits of the int32 result
// carry information.
//
// Equivalent to the ARM NEON VQRDMULH instruction. Only the int32 and int16
// specializations below are implemented.
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  (void)b;
  return a;
}

// int32 version: the same computation as the ARMv7 NEON VQRDMULH instruction.
template <>
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  const std::int32_t kMin = std::numeric_limits<std::int32_t>::min();
  // (-1) * (-1) == +1 is not representable: saturate.
  if (a == kMin && b == kMin) {
    return std::numeric_limits<std::int32_t>::max();
  }
  const std::int64_t product =
      static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
  // A bias of about half the 2^31 divisor, signed like the product, makes
  // the truncating division round to nearest (exact halves go toward
  // +infinity).
  const std::int32_t nudge = (product < 0) ? (1 - (1 << 30)) : (1 << 30);
  return static_cast<std::int32_t>((product + nudge) / (1ll << 31));
}

// int16 version: same scheme with a 32-bit intermediate and a 2^15 divisor.
template <>
inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
                                                      std::int16_t b) {
  const std::int16_t kMin = std::numeric_limits<std::int16_t>::min();
  if (a == kMin && b == kMin) {
    return std::numeric_limits<std::int16_t>::max();
  }
  const std::int32_t product =
      static_cast<std::int32_t>(a) * static_cast<std::int32_t>(b);
  const std::int16_t nudge = (product < 0) ? (1 - (1 << 14)) : (1 << 14);
  return static_cast<std::int16_t>((product + nudge) / (1 << 15));
}
- // Correctly-rounded-to-nearest division by a power-of-two.
- // Also known as a rounding arithmetic right shift.
- template <typename IntegerType, typename ExponentType>
- inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
- assert(exponent >= 0);
- assert(exponent <= 31);
- const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
- const IntegerType zero = Dup<IntegerType>(0);
- const IntegerType one = Dup<IntegerType>(1);
- const IntegerType remainder = BitAnd(x, mask);
- const IntegerType threshold =
- Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
- return Add(ShiftRight(x, exponent),
- BitAnd(MaskIfGreaterThan(remainder, threshold), one));
- }
- // Returns the product of a run-time integer value by a compile-time power
- // of two, with either a positive exponent (equivalent to an arithmetic
- // left shift, saturating) or a negative exponent (equivalent to an arithmetic
- // right shift, rounding to nearest).
- template <int Exponent, typename IntegerType,
- int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
- struct ImplSaturatingRoundingMultiplyByPOT {};
- template <int Exponent, typename IntegerType>
- struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
- static IntegerType eval(IntegerType x) { return x; }
- };
- template <int Exponent, typename IntegerType>
- struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
- static IntegerType eval(IntegerType x) {
- using ScalarIntegerType =
- typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
- const IntegerType min =
- Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
- const IntegerType max =
- Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
- const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
- const std::int32_t threshold =
- ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
- const IntegerType positive_mask =
- MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
- const IntegerType negative_mask =
- MaskIfLessThan(x, Dup<IntegerType>(-threshold));
- IntegerType result = ShiftLeft(x, Exponent);
- result = SelectUsingMask(positive_mask, max, result);
- result = SelectUsingMask(negative_mask, min, result);
- return result;
- }
- };
- template <int Exponent, typename IntegerType>
- struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
- static IntegerType eval(IntegerType x) {
- return RoundingDivideByPOT<IntegerType>(x, -Exponent);
- }
- };
- template <int Exponent, typename IntegerType>
- IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
- return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
- }
- // Part 2: the FixedPoint class.
- // A FixedPoint object represents a fixed-point value stored in the underlying
- // integer type tRawType, if tRawType is a plain scalar integer type.
- // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
- // case a FixedPoint object represents a corresponding SIMD vector of fixed
- // point values.
- //
// tIntegerBits describes the range of the fixed-point format: if
// tIntegerBits == m, then the range of representable values is the half-open
// interval [-2^m, 2^m). The open right boundary means that 2^m itself is not
// representable; how close the maximum representable value comes to it
// depends on the bit depth of tRawType.
- //
- // In "Q format notation",
- // https://en.wikipedia.org/wiki/Q_(number_format)
- // we are describing the format
- // Qm.n
- // where
- // m = tIntegerBits
- // and
- // n = NumberOfBits(tRawType) - (m + 1)
- // Note that the (m + 1) in the above line is because we adopt the convention
- // that we count the integer bits exclusively of the sign bit; so (m + 1) is
- // the total number of integer bits inclusive of the sign bit.
- //
- // Accordingly, the number of integral representable values in our range
- // [-2^m ; 2^m)
- // is equal to 2^(m+1).
- template <typename tRawType, int tIntegerBits>
- class FixedPoint {
- public:
- typedef tRawType RawType;
- typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
- typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
- static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
- static constexpr int kIntegerBits = tIntegerBits;
- static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
- static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
- "bad IntegerBits");
- typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
- static const ScalarRawType ScalarRawMin() {
- return std::numeric_limits<ScalarRawType>::min();
- }
- static const ScalarRawType ScalarRawMax() {
- return std::numeric_limits<ScalarRawType>::max();
- }
- static const ScalarRawType RawMin() {
- return VectorFromScalar(ScalarRawMin());
- }
- static const ScalarRawType RawMax() {
- return VectorFromScalar(ScalarRawMax());
- }
- static FixedPoint FromRaw(RawType x) {
- FixedPoint retval;
- retval.raw() = x;
- return retval;
- }
- static FixedPoint FromScalarRaw(ScalarRawType x) {
- FixedPoint retval;
- retval.raw() = Dup<RawType>(x);
- return retval;
- }
- static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
- return FromScalarRaw(x.raw());
- }
- template <int Exponent>
- static FixedPoint ConstantPOT() {
- static constexpr int kOffset = kFractionalBits + Exponent;
- static_assert(
- kOffset < 31,
- "Constant not exactly representable in this fixed-point format");
- return FromScalarRaw(ScalarRawType(1) << kOffset);
- }
- static FixedPoint Zero() { return FromScalarRaw(0); }
- static FixedPoint One() {
- return FromScalarRaw(
- kIntegerBits == 0
- ? ScalarRawMax()
- : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
- }
- static FixedPoint FromDouble(double x) {
- const double min_bound = static_cast<double>(ScalarRawMin());
- const double max_bound = static_cast<double>(ScalarRawMax());
- return FromScalarRaw(static_cast<ScalarRawType>(std::min(
- std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
- min_bound),
- max_bound)));
- }
- RawType raw() const { return i_; }
- RawType& raw() { return i_; }
- private:
- RawType i_;
- };
- // Part 3: implementation of arithmetic operators for the
- // FixedPoint class, and a few related functions.
- // A FixedPoint multiplication is just a
- // SaturatingRoundingDoublingHighMul operation on the underlying
- // raw integer values. The IntegerBits simply add up, as is obvious
- // from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
- template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
- FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
- FixedPoint<tRawType, tIntegerBits_a> a,
- FixedPoint<tRawType, tIntegerBits_b> b) {
- FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
- c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
- return c;
- }
- // Tweaking IntegerBits gives exact multiplication by a power of two.
- template <int tExponent, typename tRawType, int tIntegerBits>
- FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
- FixedPoint<tRawType, tIntegerBits> a) {
- FixedPoint<tRawType, tExponent + tIntegerBits> c;
- c.raw() = a.raw();
- return c;
- }
- // If we want to leave IntegerBits fixed, then multiplication
- // by a power of two has to be saturating/rounding, not exact anymore.
- template <int tExponent, typename tRawType, int tIntegerBits>
- FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
- FixedPoint<tRawType, tIntegerBits> a) {
- return FixedPoint<tRawType, tIntegerBits>::FromRaw(
- SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
- }
- // Generic arithmetic operators.
- #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \
- template <typename tRawType, int tIntegerBits> \
- FixedPoint<tRawType, tIntegerBits> FuncName( \
- FixedPoint<tRawType, tIntegerBits> a) { \
- return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
- }
- #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
- template <typename tRawType, int tIntegerBits> \
- FixedPoint<tRawType, tIntegerBits> FuncName( \
- FixedPoint<tRawType, tIntegerBits> a, \
- FixedPoint<tRawType, tIntegerBits> b) { \
- return FixedPoint<tRawType, tIntegerBits>::FromRaw( \
- ImplFuncName(a.raw(), b.raw())); \
- }
- MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
- MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
- MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
- MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
- MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
- MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
- MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
- MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
- #undef MAKE_FIXEDPOINT_UNARY_FUNC
- #undef MAKE_FIXEDPOINT_BINARY_FUNC
- #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \
- template <typename tRawType, int tIntegerBits> \
- tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
- return FuncName(a.raw()); \
- }
- #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
- template <typename tRawType, int tIntegerBits> \
- tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a, \
- FixedPoint<tRawType, tIntegerBits> b) { \
- return FuncName(a.raw(), b.raw()); \
- }
- MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
- MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
- MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
- MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
- MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
- MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
- MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
- MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
- #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
- #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
- template <typename tRawType, int tIntegerBits>
- FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
- tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
- FixedPoint<tRawType, tIntegerBits> else_val) {
- return FixedPoint<tRawType, tIntegerBits>::FromRaw(
- SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
- }
- template <typename tRawType, int tIntegerBits>
- bool operator==(FixedPoint<tRawType, tIntegerBits> a,
- FixedPoint<tRawType, tIntegerBits> b) {
- return All(MaskIfEqual(a.raw(), b.raw()));
- }
- template <typename tRawType, int tIntegerBits>
- bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
- FixedPoint<tRawType, tIntegerBits> b) {
- return !(a == b);
- }
- template <typename tRawType, int tIntegerBits>
- FixedPoint<tRawType, tIntegerBits> SaturatingAdd(
- FixedPoint<tRawType, tIntegerBits> a,
- FixedPoint<tRawType, tIntegerBits> b) {
- return FixedPoint<tRawType, tIntegerBits>::FromRaw(
- SaturatingAdd(a.raw(), b.raw()));
- }
- template <typename tRawType, int tIntegerBits>
- FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(
- FixedPoint<tRawType, tIntegerBits> a,
- FixedPoint<tRawType, tIntegerBits> b) {
- return FixedPoint<tRawType, tIntegerBits>::FromRaw(
- AddSaturatingIf16Bit(a.raw(), b.raw()));
- }
- // Conversion to floating-point.
- template <typename tRawType, int tIntegerBits>
- double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
- static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
- "not applicable to SIMD types");
- typedef FixedPoint<tRawType, tIntegerBits> F;
- return x.raw() / static_cast<double>(1ll << F::kFractionalBits);
- }
- // Rescale changes the number of IntegerBits and updates the underlying
- // raw integer value accordingly.
- template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
- FixedPoint<tRawType, tIntegerBitsDst> Rescale(
- FixedPoint<tRawType, tIntegerBitsSrc> x) {
- static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
- FixedPoint<tRawType, tIntegerBitsDst> result;
- result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
- return result;
- }
- // CheckedFixedPointConstant allows to specify fixed-point constants
- // initialized as real numbers, in a way that does not compile floating-point
- // arithmetic in production code, yet still checks agreement with the
- // floating-point expressions when asserts are enabled.
- //
- // The raw integer value provided is always a int32, encoding a 32-bit
- // fixed-point value, regardless of the actual Scalar type. This allows
- // writing generic code that applies just as well to the 32-bit and 16-bit
- // cases. In the 16-bit case, the raw integer value is internally
- // rounding-shifted by 16 bits to the right.
- template <typename FixedPointType>
- inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(
- std::int32_t int32_value) {
- typedef typename FixedPointType::ScalarRawType ScalarRawType;
- static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
- return static_cast<ScalarRawType>(
- RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));
- }
#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
// Builds FixedPointType from raw_value and asserts that it equals
// FixedPointType::FromDouble(double_value). The floating-point conversion
// only runs inside the assert, so it disappears when asserts are disabled.
template <typename FixedPointType>
FixedPointType CheckedFixedPointConstant(std::int32_t raw_value,
                                         double double_value) {
  const FixedPointType constant = FixedPointType::FromScalarRaw(raw_value);
  assert(constant == FixedPointType::FromDouble(double_value));
  return constant;
}
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType,     \
                                             ScalarRawInt32Value, \
                                             DoubleValue)         \
  (gemmlowp::CheckedFixedPointConstant<FixedPointType>(           \
      gemmlowp::RescaleConstantInitializer<FixedPointType>(       \
          ScalarRawInt32Value),                                   \
      DoubleValue))

#else
// Unchecked variant: DoubleValue is ignored and never evaluated.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType,     \
                                             ScalarRawInt32Value, \
                                             DoubleValue)         \
  (FixedPointType::FromScalarRaw(                                 \
      gemmlowp::RescaleConstantInitializer<FixedPointType>(       \
          ScalarRawInt32Value)))

#endif
- // Implementation of exponential function.
- // Returns exp(x) for x in [-1/4, 0).
- template <typename tRawType>
- FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
- FixedPoint<tRawType, 0> a) {
- typedef FixedPoint<tRawType, 0> F;
- const F constant_term =
- GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
- const F constant_1_over_3 =
- GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
- // We're evaluating a Taylor expansion around -1/8, so we do the change of
- // variable: x = a + 1/8.
- // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
- F x = a + F::template ConstantPOT<-3>();
- F x2 = x * x;
- F x3 = x2 * x;
- F x4 = x2 * x2;
- F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
- F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
- SaturatingRoundingMultiplyByPOT<-1>(
- ((x4_over_4 + x3) * constant_1_over_3) + x2);
- return AddSaturatingIf16Bit(
- constant_term,
- constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
- }
// Returns exp(x) for x < 0.
//
// Method, as implemented below:
//  1. Split `a` into a_mod_quarter_minus_one_quarter, which lies in
//     [-1/4, 0), and a non-negative multiple of 1/4 called `remainder`,
//     such that a == a_mod_quarter_minus_one_quarter - remainder.
//  2. Evaluate exp on the [-1/4, 0) part with
//     exp_on_interval_between_negative_one_quarter_and_0_excl.
//  3. For every power-of-two bit 2^Exponent (Exponent = -2 .. +4) that is set
//     in `remainder`, multiply the result by exp(-2^Exponent).
//  4. If the input format has more than 5 integer bits, inputs below -32 are
//     mapped to 0, because step 3 only covers remainder bits up to 16.
//  5. a == 0 returns ResultF::One().
//
// tRawType is the raw storage type; tIntegerBits is the number of integer bits
// of the input format. The result always has 0 integer bits.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> exp_on_negative_values(
    FixedPoint<tRawType, tIntegerBits> a) {
  typedef FixedPoint<tRawType, tIntegerBits> InputF;
  typedef FixedPoint<tRawType, 0> ResultF;
  static constexpr int kFractionalBits = InputF::kFractionalBits;
  static constexpr int kIntegerBits = InputF::kIntegerBits;
  const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
  // All raw bits strictly below the 1/4 bit.
  InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
  // (a mod 1/4) - 1/4, which lies in [-1/4, 0).
  InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
  ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
      Rescale<0>(a_mod_quarter_minus_one_quarter));
  // Non-negative multiple of 1/4 with
  // a == a_mod_quarter_minus_one_quarter - remainder, as a raw value.
  tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();

  // If the input format has a 2^Exponent bit and that bit of `remainder` is
  // set, multiply `result` by exp(-2^Exponent). kShiftAmount (the raw bit
  // index of 2^Exponent) falls back to 0 when the bit does not exist. This
  // keeps the shift well-defined: the guarding `if` is an ordinary runtime
  // `if`, so its body is compiled even when the condition is false.
#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
  if (kIntegerBits > Exponent) {                                            \
    const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
        ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
    static constexpr int kShiftAmount =                                     \
        kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;           \
    result = SelectUsingMask(                                               \
        MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
        result * kMultiplier, result);                                      \
  }

  // Constants below are Q0 representations of negative exp fractionals.
  // They are written as int32 raw values (value * 2^31); for example,
  // 1672461947 / 2^31 ~= exp(-1/4).
  GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);  // exp(-1/4)
  GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);  // exp(-1/2)
  GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);   // exp(-1)
  GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);   // exp(-2)
  GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);    // exp(-4)
  GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);      // exp(-8)
  GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);         // exp(-16)
#undef GEMMLOWP_EXP_BARREL_SHIFTER

  // The shifter above stops at 16, so remainders of 32 or more are not
  // represented. With more than 5 integer bits such inputs can occur, and any
  // input below -32 is mapped to 0.
  // -(1 << clampB) is the int32 raw value of -32.0 with kIntegerBits integer
  // bits: -32 * 2^(31 - kIntegerBits) == -2^(36 - kIntegerBits). clampB falls
  // back to 0 when the clamp is not needed, keeping the shift in range.
  static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
  if (kIntegerBits > 5) {
    const InputF clamp =
        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
    result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
  }

  // exp(0) == 1: a zero input returns ResultF::One().
  result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
  return result;
}
- // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).
- // Returns (1 - x) / (1 + x) for x in (0, 1).
- template <typename tRawType>
- FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
- FixedPoint<tRawType, 0> a) {
- typedef FixedPoint<tRawType, 0> F0;
- typedef FixedPoint<tRawType, 2> F2;
- F0 half_denominator = RoundingHalfSum(a, F0::One());
- // Newton-Raphson division
- // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
- // Refer to that page for the logic behind the 48/17 and 32/17 constants.
- const F2 constant_48_over_17 =
- GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
- const F2 constant_neg_32_over_17 =
- GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
- F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
- for (int i = 0; i < 3; i++) {
- F2 half_denominator_times_x = half_denominator * x;
- F2 one_minus_half_denominator_times_x =
- F2::One() - half_denominator_times_x;
- x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
- }
- return Rescale<0>(x - F2::One());
- }
- // Returns -tanh(x) for x < 0.
- template <typename tRawType, int tIntegerBits>
- FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
- FixedPoint<tRawType, tIntegerBits> a) {
- return one_minus_x_over_one_plus_x_for_x_in_0_1(
- exp_on_negative_values(ExactMulByPot<1>(a)));
- }
- // Returns tanh(x) for any x.
- template <typename tRawType, int tIntegerBits>
- FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
- typedef FixedPoint<tRawType, tIntegerBits> InputF;
- typedef FixedPoint<tRawType, 0> ResultF;
- tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
- tRawType mask_if_zero = MaskIfZero(a);
- InputF n = SelectUsingMask(mask_if_negative, a, -a);
- ResultF t = neg_tanh_on_negative_values(n);
- return SelectUsingMask(mask_if_zero, ResultF::Zero(),
- SelectUsingMask(mask_if_negative, -t, t));
- }
- // Implementation of logistic function.
- // Returns 1 / (1 + x) for x in (0, 1).
- template <typename tRawType>
- FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(
- FixedPoint<tRawType, 0> a) {
- typedef FixedPoint<tRawType, 0> F0;
- typedef FixedPoint<tRawType, 2> F2;
- F0 half_denominator = RoundingHalfSum(a, F0::One());
- // Newton-Raphson division
- // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
- // Refer to that page for the logic behind the 48/17 and 32/17 constants.
- const F2 constant_48_over_17 =
- GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
- const F2 constant_neg_32_over_17 =
- GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
- F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
- for (int i = 0; i < 3; i++) {
- F2 half_denominator_times_x = half_denominator * x;
- F2 one_minus_half_denominator_times_x =
- F2::One() - half_denominator_times_x;
- x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
- }
- return Rescale<0>(ExactMulByPot<-1>(x));
- }
- // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
- template <typename tRawType, int tIntegerBits>
- FixedPoint<tRawType, 0> logistic_on_positive_values(
- FixedPoint<tRawType, tIntegerBits> a) {
- return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
- }
- // Returns logistic(x) = 1 / (1 + exp(-x)) for any x.
- template <typename tRawType, int tIntegerBits>
- FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
- typedef FixedPoint<tRawType, tIntegerBits> InputF;
- typedef FixedPoint<tRawType, 0> ResultF;
- tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero());
- tRawType mask_if_zero = MaskIfZero(a);
- InputF abs_input = SelectUsingMask(mask_if_positive, a, -a);
- ResultF result_if_positive = logistic_on_positive_values(abs_input);
- ResultF result_if_negative = ResultF::One() - result_if_positive;
- const ResultF one_half =
- GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
- return SelectUsingMask(mask_if_zero, one_half,
- SelectUsingMask(mask_if_positive, result_if_positive,
- result_if_negative));
- }
- } // end namespace gemmlowp
- #ifdef GEMMLOWP_NEON
- #include "./fixedpoint_neon.h"
- #elif defined(GEMMLOWP_AVX2)
- #include "./fixedpoint_avx.h"
- #elif defined(GEMMLOWP_SSE4)
- #include "./fixedpoint_sse.h"
- #elif defined(GEMMLOWP_MSA)
- #include "./fixedpoint_msa.h"
- #endif
- #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
|