|
|
@@ -40,9 +40,13 @@ extern "C"
|
|
|
|
|
|
#define LEFT_SHIFT(_shift) (_shift > 0 ? _shift : 0)
|
|
|
#define RIGHT_SHIFT(_shift) (_shift > 0 ? 0 : -_shift)
|
|
|
+#define MASK_IF_ZERO(x) (x) == 0 ? ~0 : 0
|
|
|
+#define MASK_IF_NON_ZERO(x) (x) != 0 ? ~0 : 0
|
|
|
+#define SELECT_USING_MASK(mask, a, b) ((mask) & (a)) ^ (~(mask) & (b))
|
|
|
|
|
|
#define MAX(A,B) ((A) > (B) ? (A) : (B))
|
|
|
#define MIN(A,B) ((A) < (B) ? (A) : (B))
|
|
|
+#define CLAMP(x, h, l) MAX(MIN((x), (h)), (l))
|
|
|
|
|
|
/**
|
|
|
* @brief Union for SIMD access of q31/q15/q7 types
|
|
|
@@ -343,6 +347,13 @@ void arm_nn_mult_q7(
|
|
|
#define NN_ROUND(out_shift) 0
|
|
|
#endif
|
|
|
|
|
|
+// Macros for shortening quantization functions' names and avoid long lines
|
|
|
+#define MUL_SAT(a, b) arm_nn_sat_doubling_high_mult((a), (b))
|
|
|
+#define MUL_POW2(a, b) arm_nn_mult_by_power_of_two((a), (b))
|
|
|
+#define DIV_POW2(a, b) arm_nn_divide_by_power_of_two((a), (b))
|
|
|
+#define EXP_ON_NEG(x) arm_nn_exp_on_negative_values((x))
|
|
|
+#define ONE_OVER1(x) arm_nn_one_over_one_plus_x_for_x_in_0_1((x))
|
|
|
+
|
|
|
/**
|
|
|
* @brief Saturating doubling high multiply. Result matches
|
|
|
* NEON instruction VQRDMULH.
|
|
|
@@ -467,6 +478,64 @@ __STATIC_FORCEINLINE int32x4_t arm_mve_requantize(const int32x4_t val, const q31
|
|
|
}
|
|
|
#endif
|
|
|
|
|
|
+// @note The following functions are used only for softmax layer, scaled bits = 5 assumed
|
|
|
+
|
|
|
+__STATIC_FORCEINLINE int32_t arm_nn_exp_on_negative_values(int32_t val)
|
|
|
+{
|
|
|
+ int32_t mask = 0;
|
|
|
+ int32_t shift = 24;
|
|
|
+
|
|
|
+ const int32_t val_mod_minus_quarter = (val & ((1 << shift) - 1)) - (1 << shift);
|
|
|
+ const int32_t remainder = val_mod_minus_quarter - val;
|
|
|
+ const int32_t x = (val_mod_minus_quarter << 5) + (1 << 28);
|
|
|
+ const int32_t x2 = MUL_SAT(x, x);
|
|
|
+
|
|
|
+ int32_t result = 1895147668 + MUL_SAT(1895147668, x +
|
|
|
+ DIV_POW2(MUL_SAT(DIV_POW2(MUL_SAT(x2, x2), 2) + MUL_SAT(x2, x), 715827883) + x2, 1));
|
|
|
+
|
|
|
+#define SELECT_IF_NON_ZERO(x) \
|
|
|
+{ \
|
|
|
+ mask = MASK_IF_NON_ZERO(remainder & (1 << shift++)); \
|
|
|
+ result = SELECT_USING_MASK(mask, MUL_SAT(result, x), result); \
|
|
|
+}
|
|
|
+
|
|
|
+ SELECT_IF_NON_ZERO(1672461947)
|
|
|
+ SELECT_IF_NON_ZERO(1302514674)
|
|
|
+ SELECT_IF_NON_ZERO(790015084)
|
|
|
+ SELECT_IF_NON_ZERO(290630308)
|
|
|
+ SELECT_IF_NON_ZERO(39332535)
|
|
|
+ SELECT_IF_NON_ZERO(720401)
|
|
|
+ SELECT_IF_NON_ZERO(242)
|
|
|
+
|
|
|
+#undef SELECT_IF_NON_ZERO
|
|
|
+
|
|
|
+ mask = MASK_IF_ZERO(val);
|
|
|
+ return SELECT_USING_MASK(mask, Q31_MAX, result);
|
|
|
+}
|
|
|
+
|
|
|
+__STATIC_FORCEINLINE q31_t arm_nn_mult_by_power_of_two(const int32_t val, const int32_t exp)
|
|
|
+{
|
|
|
+ const int32_t thresh = ((1 << (31 - exp)) - 1);
|
|
|
+ int32_t result = val << exp;
|
|
|
+ result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val > thresh), Q31_MAX, result);
|
|
|
+ result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val < -thresh), Q31_MIN, result);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+__STATIC_FORCEINLINE int32_t arm_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)
|
|
|
+{
|
|
|
+ const int64_t sum = (int64_t)val + (int64_t)Q31_MAX;
|
|
|
+ const int32_t half_denominator = (int32_t)((sum + (sum >= 0 ? 1 : -1)) / 2L);
|
|
|
+ int32_t x = 1515870810 + MUL_SAT(half_denominator, -1010580540);
|
|
|
+
|
|
|
+ const int32_t shift = (1 << 29);
|
|
|
+ x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
|
|
|
+ x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
|
|
|
+ x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
|
|
|
+
|
|
|
+ return MUL_POW2(x, 1);
|
|
|
+}
|
|
|
+
|
|
|
#ifdef __cplusplus
|
|
|
}
|
|
|
#endif
|