arm_softmax_q7.c 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. /*
  2. * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_softmax_q7.c
  21. * Description: Q7 softmax function
  22. *
  23. * $Date: February 27, 2020
  24. * $Revision: V.1.0.1
  25. *
  26. * Target Processor: Cortex-M cores
  27. *
  28. * -------------------------------------------------------------------- */
  29. #include "arm_math.h"
  30. #include "arm_nnfunctions.h"
  31. /**
  32. * @ingroup groupNN
  33. */
  34. /**
  35. * @addtogroup Softmax
  36. * @{
  37. */
  38. /**
  39. * @brief Q7 softmax function
  40. * @param[in] vec_in pointer to input vector
  41. * @param[in] dim_vec input vector dimention
  42. * @param[out] p_out pointer to output vector
  43. *
  44. * @details
  45. *
  46. * Here, instead of typical natural logarithm e based softmax, we use
  47. * 2-based softmax here, i.e.,:
  48. *
  49. * y_i = 2^(x_i) / sum(2^x_j)
  50. *
  51. * The relative output will be different here.
  52. * But mathematically, the gradient will be the same
  53. * with a log(2) scaling factor.
  54. *
  55. * If we compare the position of the max value in output of this
  56. * function with a reference float32 softmax (and thus using exp)
  57. * we see that the position of the max value is sometimes different.
  58. *
  59. * If we do statistics on lot of input vectors we can compute
  60. * an average error rate in percent. It is the percent of time
  61. * that the max will be at a position different from the one
  62. * computed with a reference float32 implementation.
  63. *
  64. * This average error rate is dependent on the vector size.
  65. * We have:
  66. *
  67. * Average error rate in percent = -0.555548 + 0.246918 dim_vec
  68. * Variance of the error rate = -0.0112281 + 0.0382476 dim_vec
  69. *
  70. *
  71. */
  72. #define Q7BITS 8
  73. #define LOG2Q7BITS 3
  74. void arm_softmax_q7(const q7_t * vec_in, const uint16_t dim_vec, q7_t * p_out )
  75. {
  76. #if defined (ARM_MATH_DSP)
  77. q31_t sum;
  78. int16_t i;
  79. uint8_t shift;
  80. q15_t base;
  81. uint16_t blkCnt;
  82. q31_t in,in1,in2;
  83. q31_t out1, out2;
  84. q31_t baseV;
  85. q31_t shiftV;
  86. const q31_t pad=0x0d0d0d0d;
  87. const q7_t *pIn=vec_in;
  88. base = -128;
  89. /* We first search for the maximum */
  90. for (i = 0; i < dim_vec; i++)
  91. {
  92. if (vec_in[i] > base)
  93. {
  94. base = vec_in[i];
  95. }
  96. }
  97. /*
  98. * So the base is set to max-8, meaning
  99. * that we ignore really small values.
  100. * anyway, they will be 0 after shrinking to q7_t.
  101. */
  102. base = base - Q7BITS;
  103. baseV = ((base & 0x0FF) << 24) | ((base & 0x0FF) << 16) | ((base & 0x0FF) << 8) | ((base & 0x0FF));
  104. sum = 0;
  105. blkCnt = dim_vec >> 2;
  106. while(blkCnt)
  107. {
  108. in=arm_nn_read_q7x4_ia(&pIn);
  109. in=__SSUB8(in,baseV);
  110. in1 = __SXTB16(__ROR(in, 8));
  111. /* extend remaining two q7_t values to q15_t values */
  112. in2 = __SXTB16(in);
  113. #ifndef ARM_MATH_BIG_ENDIAN
  114. out2 = __PKHTB(in1, in2, 16);
  115. out1 = __PKHBT(in2, in1, 16);
  116. #else
  117. out1 = __PKHTB(in1, in2, 16);
  118. out2 = __PKHBT(in2, in1, 16);
  119. #endif
  120. shiftV = __USAT16(out1,LOG2Q7BITS);
  121. sum += 0x1 << (shiftV & 0x0FF);
  122. sum += 0x1 << ((shiftV >> 16) & 0x0FF);
  123. shiftV = __USAT16(out2,LOG2Q7BITS);
  124. sum += 0x1 << (shiftV & 0x0FF);
  125. sum += 0x1 << ((shiftV >> 16) & 0x0FF);
  126. blkCnt--;
  127. }
  128. blkCnt = dim_vec & 3;
  129. while(blkCnt)
  130. {
  131. shift = (uint8_t)__USAT(*pIn++ - base, LOG2Q7BITS);
  132. sum += 0x1 << shift;
  133. blkCnt--;
  134. }
  135. /* This is effectively (0x1 << 20) / sum */
  136. int output_base = (1 << 20) / sum;
  137. pIn=vec_in;
  138. blkCnt = dim_vec >> 2;
  139. while(blkCnt)
  140. {
  141. /* Here minimum value of 13+base-vec_in[i] will be 5 */
  142. in=arm_nn_read_q7x4_ia(&pIn);
  143. in=__SSUB8(pad,in);
  144. in=__SADD8(in,baseV);
  145. in1 = __SXTB16(__ROR(in, 8));
  146. /* extend remaining two q7_t values to q15_t values */
  147. in2 = __SXTB16(in);
  148. #ifndef ARM_MATH_BIG_ENDIAN
  149. out2 = __PKHTB(in1, in2, 16);
  150. out1 = __PKHBT(in2, in1, 16);
  151. #else
  152. out1 = __PKHTB(in1, in2, 16);
  153. out2 = __PKHBT(in2, in1, 16);
  154. #endif
  155. shiftV = __USAT16(out1,5);
  156. *p_out++ = (q7_t) __SSAT((output_base >> (shiftV & 0x0FF)), 8);
  157. *p_out++ = (q7_t) __SSAT((output_base >> ((shiftV >> 16) & 0x0FF)), 8);
  158. shiftV = __USAT16(out2,5);
  159. *p_out++ = (q7_t) __SSAT((output_base >> (shiftV & 0x0FF)), 8);
  160. *p_out++ = (q7_t) __SSAT((output_base >> ((shiftV >> 16) & 0x0FF)), 8);
  161. blkCnt --;
  162. }
  163. blkCnt = dim_vec & 3;
  164. while(blkCnt)
  165. {
  166. /* Here minimum value of 13+base-vec_in[i] will be 5 */
  167. shift = (uint8_t)__USAT(13 + base - *pIn++, 5);
  168. *p_out++ = (q7_t) __SSAT((output_base >> shift), 8);
  169. blkCnt --;
  170. }
  171. #else
  172. q31_t sum;
  173. int16_t i;
  174. uint8_t shift;
  175. q15_t base;
  176. base = -128;
  177. /* We first search for the maximum */
  178. for (i = 0; i < dim_vec; i++)
  179. {
  180. if (vec_in[i] > base)
  181. {
  182. base = vec_in[i];
  183. }
  184. }
  185. /*
  186. * So the base is set to max-8, meaning
  187. * that we ignore really small values.
  188. * anyway, they will be 0 after shrinking to q7_t.
  189. */
  190. base = base - Q7BITS;
  191. sum = 0;
  192. for (i = 0; i < dim_vec; i++)
  193. {
  194. shift = (uint8_t)__USAT(vec_in[i] - base, LOG2Q7BITS);
  195. sum += 0x1 << shift;
  196. }
  197. /* This is effectively (0x1 << 20) / sum */
  198. int output_base = (1 << 20) / sum;
  199. for (i = 0; i < dim_vec; i++)
  200. {
  201. /* Here minimum value of 13+base-vec_in[i] will be 5 */
  202. shift = (uint8_t)__USAT(13 + base - vec_in[i], 5);
  203. p_out[i] = (q7_t) __SSAT((output_base >> shift), 8);
  204. }
  205. #endif
  206. }
  207. /**
  208. * @} end of Softmax group
  209. */