arm_nnsupportfunctions.h 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905
  1. /*
  2. * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_nnsupportfunctions.h
  21. * Description: Public header file of support functions for CMSIS NN Library
  22. *
  23. * $Date: March 17, 2020
  24. * $Revision: V.4.0.3
  25. *
  26. * Target Processor: Cortex-M cores
  27. * -------------------------------------------------------------------- */
  28. #ifndef _ARM_NNSUPPORTFUNCTIONS_H_
  29. #define _ARM_NNSUPPORTFUNCTIONS_H_
  30. #include "arm_math.h"
  31. #include "arm_common_tables.h"
  32. #ifdef __cplusplus
  33. extern "C"
  34. {
  35. #endif
  36. #define LEFT_SHIFT(_shift) (_shift > 0 ? _shift : 0)
  37. #define RIGHT_SHIFT(_shift) (_shift > 0 ? 0 : -_shift)
  38. #define MASK_IF_ZERO(x) (x) == 0 ? ~0 : 0
  39. #define MASK_IF_NON_ZERO(x) (x) != 0 ? ~0 : 0
  40. #define SELECT_USING_MASK(mask, a, b) ((mask) & (a)) ^ (~(mask) & (b))
  41. #define MAX(A,B) ((A) > (B) ? (A) : (B))
  42. #define MIN(A,B) ((A) < (B) ? (A) : (B))
  43. #define CLAMP(x, h, l) MAX(MIN((x), (h)), (l))
/**
 * @brief Union for SIMD access of q31/q15/q7 types.
 *
 * Allows a single 32-bit word to be reinterpreted as two q15 half-words or
 * four q7 bytes (used by the DSP-extension packed-arithmetic paths).
 */
union arm_nnword
{
    q31_t word;          /**< q31 type */
    q15_t half_words[2]; /**< q15 type */
    q7_t bytes[4];       /**< q7 type */
};

/**
 * @brief Enum for specifying activation function types.
 */
typedef enum
{
    ARM_SIGMOID = 0, /**< Sigmoid activation function */
    ARM_TANH = 1,    /**< Tanh activation function */
} arm_nn_activation_type;
/**
 * @defgroup nndata_convert Neural Network Data Conversion Functions
 *
 * Perform data type conversion in-between neural network operations
 *
 */

/**
 * @brief Converts the elements of the q7 vector to q15 vector without left-shift
 * @param[in]  *pSrc      points to the q7 input vector
 * @param[out] *pDst      points to the q15 output vector
 * @param[in]  blockSize  length of the input vector
 *
 */
void arm_q7_to_q15_no_shift(const q7_t * pSrc, q15_t * pDst, uint32_t blockSize);

/**
 * @brief Non-saturating addition of elements of a q7 vector
 * @param[in]  *input      Pointer to the q7 input vector
 * @param[out] *output     Pointer to the q31 output variable.
 * @param[in]  block_size  length of the input vector
 * \par Description:
 *
 * 2^24 samples can be added without saturating the result.
 *
 * The equation used for the conversion process is:
 *
 * <pre>
 *  sum = input[0] + input[1] + .. + input[block_size -1]
 * </pre>
 *
 * */
void arm_nn_add_q7(const q7_t *input, q31_t *output, uint32_t block_size);

/**
 * @brief Converts the elements of the q7 vector to reordered q15 vector without left-shift
 * @param[in]  *pSrc      points to the q7 input vector
 * @param[out] *pDst      points to the q15 output vector
 * @param[in]  blockSize  length of the input vector
 * @return none.
 *
 */
void arm_q7_to_q15_reordered_no_shift(const q7_t * pSrc, q15_t * pDst, uint32_t blockSize);

/**
 * @brief Converts the elements from a q7 vector to a q15 vector with an added offset
 * @param[in]  src         pointer to the q7 input vector
 * @param[out] dst         pointer to the q15 output vector
 * @param[in]  block_size  length of the input vector
 * @param[in]  offset      q7 offset to be added to each input vector element.
 *
 * \par Description:
 *
 * The equation used for the conversion process is:
 *
 * <pre>
 *  dst[n] = (q15_t) src[n] + offset;   0 <= n < block_size.
 * </pre>
 *
 */
void arm_q7_to_q15_with_offset(const q7_t *src, q15_t *dst, uint32_t block_size, q15_t offset);

/**
 * @brief Converts the elements of the q7 vector to reordered q15 vector with an added offset
 * @param[in]  src         pointer to the q7 input vector
 * @param[out] dst         pointer to the q15 output vector
 * @param[in]  block_size  length of the input vector
 * @param[in]  offset      offset to be added to each input vector element.
 * @return none.
 *
 * @details  This function does the q7 to q15 expansion with re-ordering of bytes. Re-ordering is a consequence of
 *           the sign extension intrinsic(DSP extension). The tail (i.e., last (N % 4) elements) retains its
 *           original order.
 *
 */
void arm_q7_to_q15_reordered_with_offset(const q7_t *src, q15_t *dst, uint32_t block_size, q15_t offset);

/**
 * @brief Converts the elements from a q7 vector and accumulate to a q15 vector
 * @param[out] *dst        points to the q15 output vector
 * @param[in]  *src        points to the q7 input vector
 * @param[in]  block_size  length of the input vector
 *
 * \par Description:
 *
 * The equation used for the conversion process is:
 *
 * <pre>
 *  dst[n] += (q15_t) src[n] ;   0 <= n < block_size.
 * </pre>
 *
 */
void arm_nn_accumulate_q7_to_q15(q15_t *dst, const q7_t *src, uint32_t block_size);
/**
 * @brief Depthwise conv on an im2col buffer where the input channel equals output channel.
 * @param[in]    row            pointer to row
 * @param[in]    col            pointer to im2col buffer, always consists of 2 columns.
 * @param[in]    num_ch         number of channels
 * @param[in]    out_shift      pointer to per output channel requantization shift parameter.
 * @param[in]    out_mult       pointer to per output channel requantization multiplier parameter.
 * @param[in]    out_offset     output tensor offset.
 * @param[in]    activation_min minimum value to clamp the output to. Range : int8
 * @param[in]    activation_max maximum value to clamp the output to. Range : int8
 * @param[in]    kernel_size    number of elements in one column.
 * @param[in]    output_bias    per output channel bias. Range : int32
 * @param[out]   out            pointer to output
 * @return     The function returns one of the two
 *              1. The incremented output pointer for a successful operation or
 *              2. NULL if implementation is not available.
 *
 * @details     Supported framework: TensorFlow Lite micro.
 */
q7_t *arm_nn_depthwise_conv_s8_core(const q7_t *row,
                                    const q15_t *col,
                                    const uint16_t num_ch,
                                    const int32_t *out_shift,
                                    const int32_t *out_mult,
                                    const int32_t out_offset,
                                    const int32_t activation_min,
                                    const int32_t activation_max,
                                    const uint16_t kernel_size,
                                    const int32_t *const output_bias,
                                    q7_t *out);

/**
 * @brief General Matrix-multiplication function with per-channel requantization.
 * @param[in]    input_row          pointer to row operand
 * @param[in]    input_col          pointer to col operand
 * @param[in]    output_ch          number of rows of input_row
 * @param[in]    col_batches        number of column batches. Range: 1 to 4
 * @param[in]    output_shift       pointer to per output channel requantization shift parameter.
 * @param[in]    output_mult        pointer to per output channel requantization multiplier parameter.
 * @param[in]    out_offset         output tensor offset.
 * @param[in]    col_offset         input tensor(col) offset.
 * @param[in]    row_offset         kernel offset(row). Not used.
 * @param[in]    out_activation_min minimum value to clamp the output to. Range : int8
 * @param[in]    out_activation_max maximum value to clamp the output to. Range : int8
 * @param[in]    row_len            number of elements in each row
 * @param[in]    bias               per output channel bias. Range : int32
 * @param[in,out] out               pointer to output
 * @return     The function returns one of the two
 *              1. The incremented output pointer for a successful operation or
 *              2. NULL if implementation is not available.
 *
 * @details   Supported framework: TensorFlow Lite
 */
q7_t *arm_nn_mat_mult_s8(const q7_t *input_row,
                         const q7_t *input_col,
                         const uint16_t output_ch,
                         const uint16_t col_batches,
                         const int32_t *output_shift,
                         const int32_t *output_mult,
                         const int32_t out_offset,
                         const int32_t col_offset,
                         const int32_t row_offset,
                         const int16_t out_activation_min,
                         const int16_t out_activation_max,
                         const uint16_t row_len,
                         const int32_t *const bias,
                         q7_t *out);

/**
 * @brief General Matrix-multiplication without requantization for one row & one column
 * @param[in]  row_elements  number of row elements
 * @param[in]  row_base      pointer to row operand
 * @param[in]  col_base      pointer to col operand
 * @param[out] sum_col       pointer to store sum of column elements
 * @param[out] output        pointer to store result of multiply-accumulate
 * @return     The function returns the multiply-accumulated result of the row by column.
 *
 * @details Pseudo-code
 *      *output = 0
 *      sum_col = 0
 *      for (i = 0; i < row_elements; i++)
 *          *output += row_base[i] * col_base[i]
 *          sum_col += col_base[i]
 *
 */
arm_status arm_nn_mat_mul_core_1x_s8(int32_t row_elements,
                                     const int8_t *row_base,
                                     const int8_t *col_base,
                                     int32_t *const sum_col,
                                     int32_t *const output);

/**
 * @brief General Matrix-multiplication without requantization for four rows and one column
 * @param[in]  row_elements  number of row elements
 * @param[in]  offset        offset between rows. Can be the same as row_elements.
 *                           For e.g, in a 1x1 conv scenario with stride as 1.
 * @param[in]  row_base      pointer to row operand
 * @param[in]  col_base      pointer to col operand
 * @param[out] sum_col       pointer to store sum of column elements
 * @param[out] output        pointer to store result(4 int32's) of multiply-accumulate
 * @return     The function returns the multiply-accumulated result of the row by column
 *
 * @details Pseudo-code
 *      output[0] = 0
 *      ..
 *      output[3] = 0
 *      sum_col = 0
 *      for (i = 0; i < row_elements; i++)
 *          output[0] += row_base[i] * col_base[i]
 *          ..
 *          output[3] += row_base[i + (row_elements * 3)] * col_base[i]
 *          sum_col += col_base[i]
 */
arm_status arm_nn_mat_mul_core_4x_s8(const int32_t row_elements,
                                     const int32_t offset,
                                     const int8_t *row_base,
                                     const int8_t *col_base,
                                     int32_t *const sum_col,
                                     int32_t *const output);
/**
 * @brief General Matrix-multiplication function with per-channel requantization.
 *        This function assumes:
 *        - LHS input matrix NOT transposed (nt)
 *        - RHS input matrix transposed (t)
 *
 * @note This operation also performs the broadcast bias addition before the requantization
 *
 * @param[in]  lhs             Pointer to the LHS input matrix
 * @param[in]  rhs             Pointer to the RHS input matrix
 * @param[in]  bias            Pointer to the bias vector. The length of this vector is equal to the number of
 *                             output columns (or RHS input rows)
 * @param[out] dst             Pointer to the output matrix with "m" rows and "n" columns
 * @param[in]  dst_multipliers Pointer to the multipliers vector needed for the per-channel requantization.
 *                             The length of this vector is equal to the number of output columns (or RHS input rows)
 * @param[in]  dst_shifts      Pointer to the shifts vector needed for the per-channel requantization.
 *                             The length of this vector is equal to the number of output columns (or RHS input rows)
 * @param[in]  lhs_rows        Number of LHS input rows
 * @param[in]  rhs_rows        Number of RHS input rows
 * @param[in]  rhs_cols        Number of LHS/RHS input columns
 * @param[in]  lhs_offset      Offset to be applied to the LHS input value
 * @param[in]  dst_offset      Offset to be applied the output result
 * @param[in]  activation_min  Minimum value to clamp down the output. Range : int8
 * @param[in]  activation_max  Maximum value to clamp up the output. Range : int8
 *
 * @return     The function returns <code>ARM_MATH_SUCCESS</code>
 *
 */
arm_status arm_nn_mat_mult_nt_t_s8(const q7_t *lhs,
                                   const q7_t *rhs,
                                   const q31_t *bias,
                                   q7_t *dst,
                                   const int32_t *dst_multipliers,
                                   const int32_t *dst_shifts,
                                   const int32_t lhs_rows,
                                   const int32_t rhs_rows,
                                   const int32_t rhs_cols,
                                   const int32_t lhs_offset,
                                   const int32_t dst_offset,
                                   const int32_t activation_min,
                                   const int32_t activation_max);

/**
 * @brief s8 Vector by Matrix (transposed) multiplication
 *
 * @param[in]  lhs             Input left-hand side vector
 * @param[in]  rhs             Input right-hand side matrix (transposed)
 * @param[in]  bias            Input bias
 * @param[out] dst             Output vector
 * @param[in]  lhs_offset      Offset to be added to the input values of the left-hand side vector. Range: -127 to 128
 * @param[in]  rhs_offset      Offset to be added to the input values of the right-hand side matrix. Range: -127 to 128
 * @param[in]  dst_offset      Offset to be added to the output values. Range: -127 to 128
 * @param[in]  dst_multiplier  Output multiplier
 * @param[in]  dst_shift       Output shift
 * @param[in]  rhs_cols        Number of columns in the right-hand side input matrix
 * @param[in]  rhs_rows        Number of rows in the right-hand side input matrix
 * @param[in]  activation_min  Minimum value to clamp the output to. Range: int8
 * @param[in]  activation_max  Maximum value to clamp the output to. Range: int8
 *
 * @return     The function returns <code>ARM_MATH_SUCCESS</code>
 *
 */
arm_status arm_nn_vec_mat_mult_t_s8(const q7_t *lhs,
                                    const q7_t *rhs,
                                    const q31_t *bias,
                                    q7_t *dst,
                                    const int32_t lhs_offset,
                                    const int32_t rhs_offset,
                                    const int32_t dst_offset,
                                    const int32_t dst_multiplier,
                                    const int32_t dst_shift,
                                    const int32_t rhs_cols,
                                    const int32_t rhs_rows,
                                    const int32_t activation_min,
                                    const int32_t activation_max);

/**
 * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in padded cases where
 *        the padding is -lhs_offset(Range: int8). Dimensions are the same for lhs and rhs.
 *
 * @param[in]  lhs             Input left-hand side matrix
 * @param[in]  rhs             Input right-hand side matrix (transposed)
 * @param[in]  lhs_offset      LHS matrix offset(input offset). Range: -127 to 128
 * @param[in]  num_ch          Number of channels in LHS/RHS
 * @param[in]  out_shift       Per channel output shift. Length of vector is equal to number of channels
 * @param[in]  out_mult        Per channel output multiplier. Length of vector is equal to number of channels
 * @param[in]  out_offset      Offset to be added to the output values. Range: -127 to 128
 * @param[in]  activation_min  Minimum value to clamp the output to. Range: int8
 * @param[in]  activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]  row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
 * @param[in]  output_bias     Per channel output bias. Length of vector is equal to number of channels
 * @param[in]  out             Output pointer
 *
 * @return     The function returns one of the two
 *             - Updated output pointer if an implementation is available
 *             - NULL if no implementation is available.
 *
 * @note       If number of channels is not a multiple of 4, upto 3 elements outside the boundary will be read
 *             out for the following.
 *             - Output shift
 *             - Output multiplier
 *             - Output bias
 *             - rhs
 */
q7_t *arm_nn_depthwise_conv_nt_t_padded_s8(const q7_t *lhs,
                                           const q7_t *rhs,
                                           const int32_t lhs_offset,
                                           const uint16_t num_ch,
                                           const int32_t *out_shift,
                                           const int32_t *out_mult,
                                           const int32_t out_offset,
                                           const int32_t activation_min,
                                           const int32_t activation_max,
                                           const uint16_t row_x_col,
                                           const int32_t *const output_bias,
                                           q7_t *out);

/**
 * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases.
 *        Dimensions are the same for lhs and rhs.
 *
 * @param[in]  lhs             Input left-hand side matrix
 * @param[in]  rhs             Input right-hand side matrix (transposed)
 * @param[in]  lhs_offset      LHS matrix offset(input offset). Range: -127 to 128
 * @param[in]  num_ch          Number of channels in LHS/RHS
 * @param[in]  out_shift       Per channel output shift. Length of vector is equal to number of channels.
 * @param[in]  out_mult        Per channel output multiplier. Length of vector is equal to number of channels.
 * @param[in]  out_offset      Offset to be added to the output values. Range: -127 to 128
 * @param[in]  activation_min  Minimum value to clamp the output to. Range: int8
 * @param[in]  activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]  row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
 * @param[in]  output_bias     Per channel output bias. Length of vector is equal to number of channels.
 * @param[in]  out             Output pointer
 *
 * @return     The function returns one of the two
 *             - Updated output pointer if an implementation is available
 *             - NULL if no implementation is available.
 *
 * @note       If number of channels is not a multiple of 4, upto 3 elements outside the boundary will be read
 *             out for the following.
 *             - Output shift
 *             - Output multiplier
 *             - Output bias
 *             - rhs
 */
q7_t *arm_nn_depthwise_conv_nt_t_s8(const q7_t *lhs,
                                    const q7_t *rhs,
                                    const int32_t lhs_offset,
                                    const uint16_t num_ch,
                                    const int32_t *out_shift,
                                    const int32_t *out_mult,
                                    const int32_t out_offset,
                                    const int32_t activation_min,
                                    const int32_t activation_max,
                                    const uint16_t row_x_col,
                                    const int32_t *const output_bias,
                                    q7_t *out);
  423. /**
  424. @brief Read 2 q15 elements and post increment pointer.
  425. @param[in] in_q15 Pointer to pointer that holds address of input.
  426. @return q31 value
  427. */
  428. __STATIC_FORCEINLINE q31_t arm_nn_read_q15x2_ia(const q15_t **in_q15)
  429. {
  430. q31_t val;
  431. memcpy(&val, *in_q15, 4);
  432. *in_q15 += 2;
  433. return (val);
  434. }
  435. /**
  436. @brief Read 4 q7 from q7 pointer and post increment pointer.
  437. @param[in] in_q7 Pointer to pointer that holds address of input.
  438. @return q31 value
  439. */
  440. __STATIC_FORCEINLINE q31_t arm_nn_read_q7x4_ia(const q7_t **in_q7)
  441. {
  442. q31_t val;
  443. memcpy(&val, *in_q7, 4);
  444. *in_q7 += 4;
  445. return (val);
  446. }
  447. /**
  448. @brief Read 2 q15 from q15 pointer.
  449. @param[in] in_q15 pointer to address of input.
  450. @return q31 value
  451. */
  452. __STATIC_FORCEINLINE q31_t arm_nn_read_q15x2(const q15_t *in_q15)
  453. {
  454. q31_t val;
  455. memcpy(&val, in_q15, 4);
  456. return (val);
  457. }
  458. /**
  459. @brief Read 4 q7 values.
  460. @param[in] in_q7 pointer to address of input.
  461. @return q31 value
  462. */
  463. __STATIC_FORCEINLINE q31_t arm_nn_read_q7x4(const q7_t *in_q7)
  464. {
  465. q31_t val;
  466. memcpy(&val, in_q7, 4);
  467. return (val);
  468. }
/**
 * @brief memset optimized for MVE
 * @param[in, out] dst         Destination pointer
 * @param[in]      val         Value to set
 * @param[in]      block_size  Number of bytes to set.
 *
 * @details On MVE targets a hand-written tail-predicated loop (WLSTP/LETP)
 *          stores 16 bytes per iteration; the predication handles any
 *          remainder so no scalar tail loop is needed. Clobbers q0 and lr
 *          (r14). On all other targets this falls back to the C library
 *          memset.
 */
__STATIC_FORCEINLINE void arm_memset_q7(q7_t *dst,
                                        const q7_t val,
                                        uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
    __asm volatile (
        "   vdup.8                  q0, %[set_val]             \n"
        "   wlstp.8                 lr, %[cnt], 1f             \n"
        "2:                                                    \n"
        "   vstrb.8                 q0, [%[in]], 16            \n"
        "   letp                    lr, 2b                     \n"
        "1:                                                    \n"
        :[in] "+r"(dst)
        :[cnt] "r"(block_size), [set_val] "r"(val)
        :"q0", "memory", "r14");
#else
    memset(dst, val, block_size);
#endif
}
#if defined (ARM_MATH_DSP)

/**
 * @brief read and expand one q7 word into two q15 words
 *
 * @details Loads 4 q7 values and sign-extends them into two packed q31 words
 *          of two q15 values each, preserving element order. __SXTB16
 *          sign-extends bytes 0 and 2 of its operand; combined with an 8-bit
 *          rotate it covers bytes 1 and 3, and __PKHBT/__PKHTB re-pack the
 *          half-words into the original order (swapped under big-endian).
 *          Returns the source pointer advanced by 4.
 */
__STATIC_FORCEINLINE const q7_t *read_and_pad(const q7_t *source, q31_t * out1, q31_t * out2)
{
    q31_t inA = arm_nn_read_q7x4_ia(&source);
    q31_t inAbuf1 = __SXTB16(__ROR(inA, 8));
    q31_t inAbuf2 = __SXTB16(inA);

#ifndef ARM_MATH_BIG_ENDIAN
    *out2 = __PKHTB(inAbuf1, inAbuf2, 16);
    *out1 = __PKHBT(inAbuf2, inAbuf1, 16);
#else
    *out1 = __PKHTB(inAbuf1, inAbuf2, 16);
    *out2 = __PKHBT(inAbuf2, inAbuf1, 16);
#endif

    return source;
}
/**
 * @brief read and expand one q7 word into two q15 words with reordering
 *
 * @details Same expansion as read_and_pad() but skips the final re-packing
 *          step, so the q15 elements come out interleaved (bytes 0/2 in one
 *          word, bytes 1/3 in the other). Callers must consume the matching
 *          "reordered" operand layout. Returns the source pointer advanced
 *          by 4.
 */
__STATIC_FORCEINLINE const q7_t *read_and_pad_reordered(const q7_t *source, q31_t * out1, q31_t * out2)
{
    q31_t inA = arm_nn_read_q7x4_ia(&source);
#ifndef ARM_MATH_BIG_ENDIAN
    *out2 = __SXTB16(__ROR(inA, 8));
    *out1 = __SXTB16(inA);
#else
    *out1 = __SXTB16(__ROR(inA, 8));
    *out2 = __SXTB16(inA);
#endif

    return source;
}
/**
 * @brief read and expand one q7 word into two q15 words with reordering and add an offset
 *
 * @details Like read_and_pad_reordered(), then adds `offset` to every q15
 *          element using saturating packed addition (__QADD16). `offset` is
 *          expected to hold the same q15 value duplicated in both half-words.
 *          Returns the source pointer advanced by 4.
 */
__STATIC_FORCEINLINE const q7_t *read_and_pad_reordered_with_offset(const q7_t *source, q31_t * out1, q31_t * out2, q31_t offset)
{
    q31_t inA = arm_nn_read_q7x4_ia(&source);

#ifndef ARM_MATH_BIG_ENDIAN
    *out2 = __SXTB16(__ROR(inA, 8));
    *out1 = __SXTB16(inA);
#else
    *out1 = __SXTB16(__ROR(inA, 8));
    *out2 = __SXTB16(inA);
#endif
    *out1 = __QADD16(*out1,offset);
    *out2 = __QADD16(*out2,offset);

    return source;
}
#endif
/**
 * @defgroup NNBasicMath Basic Math Functions for Neural Network Computation
 *
 * Basic Math Functions for Neural Network Computation
 *
 */

/**
 * @brief           q15 vector multiplication with variable output shifts
 * @param[in]       *pSrcA        pointer to the first input vector
 * @param[in]       *pSrcB        pointer to the second input vector
 * @param[out]      *pDst         pointer to the output vector
 * @param[in]       out_shift     amount of right-shift for output
 * @param[in]       blockSize     number of samples in each vector
 * @return none.
 *
 * <b>Scaling and Overflow Behavior:</b>
 * \par
 * The function uses saturating arithmetic.
 * Results outside of the allowable q15 range [0x8000 0x7FFF] will be saturated.
 */
void arm_nn_mult_q15(
    q15_t * pSrcA,
    q15_t * pSrcB,
    q15_t * pDst,
    const uint16_t out_shift,
    uint32_t blockSize);

/**
 * @brief           q7 vector multiplication with variable output shifts
 * @param[in]       *pSrcA        pointer to the first input vector
 * @param[in]       *pSrcB        pointer to the second input vector
 * @param[out]      *pDst         pointer to the output vector
 * @param[in]       out_shift     amount of right-shift for output
 * @param[in]       blockSize     number of samples in each vector
 * @return none.
 *
 * <b>Scaling and Overflow Behavior:</b>
 * \par
 * The function uses saturating arithmetic.
 * Results outside of the allowable q7 range [0x80 0x7F] will be saturated.
 */
void arm_nn_mult_q7(
    q7_t * pSrcA,
    q7_t * pSrcB,
    q7_t * pDst,
    const uint16_t out_shift,
    uint32_t blockSize);

/**
 * @brief macro for adding rounding offset
 *
 * NOTE(review): `out_shift` is not parenthesized in the expansion; callers
 * pass plain identifiers today — confirm before passing expressions.
 */
#ifndef ARM_NN_TRUNCATE
#define NN_ROUND(out_shift) ( (0x1u << out_shift) >> 1 )
#else
#define NN_ROUND(out_shift) 0
#endif

// Macros for shortening quantization functions' names and avoid long lines
#define MUL_SAT(a, b)  arm_nn_sat_doubling_high_mult((a), (b))
#define MUL_SAT_MVE(a, b) arm_sat_doubling_high_mult_mve_32x4((a), (b))
#define MUL_POW2(a, b) arm_nn_mult_by_power_of_two((a), (b))

#define DIV_POW2(a, b) arm_nn_divide_by_power_of_two((a), (b))
#define DIV_POW2_MVE(a, b) arm_divide_by_power_of_two_mve((a), (b))

#define EXP_ON_NEG(x)  arm_nn_exp_on_negative_values((x))
#define ONE_OVER1(x)   arm_nn_one_over_one_plus_x_for_x_in_0_1((x))
/**
 * @brief           Saturating doubling high multiply. Result matches
 *                  NEON instruction VQRDMULH.
 * @param[in]       m1        Multiplicand
 * @param[in]       m2        Multiplier
 * @return          Result of multiplication: rounded high 32 bits of
 *                  2 * m1 * m2, saturated to Q31_MAX when both inputs are
 *                  Q31_MIN (the only overflowing combination).
 *
 */
__STATIC_FORCEINLINE q31_t arm_nn_sat_doubling_high_mult(const q31_t m1, const q31_t m2)
{
    q31_t result = 0;
    // Rounding offset to add for a right shift of 31
    q63_t mult = 1 << 30;

    // When the product is negative, bias by (1 - 2^30) instead so the
    // midpoint rounds away from zero, matching VQRDMULH.
    if ((m1 < 0) ^ (m2 < 0))
    {
        mult = 1 - mult;
    }
    // Gets resolved as a SMLAL instruction
    mult = mult + (q63_t)m1 * m2;

    // Utilize all of the upper 32 bits. This is the doubling step
    // as well.
    result = mult / (1UL << 31);

    // Q31_MIN * Q31_MIN doubled overflows q31; clamp to Q31_MAX.
    if ((m1 == m2) && (m1 == (int32_t)Q31_MIN))
    {
        result = Q31_MAX;
    }
    return result;
}
/**
 * @brief           Rounding divide by power of two.
 * @param[in]       dividend - Dividend
 * @param[in]       exponent - Divisor = power(2, exponent)
 *                             Range: [0, 31]
 * @return          Rounded result of division. Midpoint is rounded away from zero.
 *
 */
__STATIC_FORCEINLINE q31_t arm_nn_divide_by_power_of_two(const q31_t dividend, const q31_t exponent)
{
    q31_t result = 0;
    // Bits that are discarded by the shift; used to decide rounding.
    const q31_t remainder_mask = (1l << exponent) - 1;
    int32_t remainder = remainder_mask & dividend;

    // Basic division
    result = dividend >> exponent;

    // Adjust 'result' for rounding (mid point away from zero)
    q31_t threshold = remainder_mask >> 1;
    // For negative results the midpoint must round toward more-negative,
    // i.e. away from zero, so raise the threshold by one.
    if (result < 0)
    {
        threshold++;
    }
    if (remainder > threshold)
    {
        result++;
    }

    return result;
}
/**
 * @brief           Requantize a given value.
 * @param[in]       val         Value to be requantized
 * @param[in]       multiplier  multiplier
 * @param[in]       shift       left or right shift for 'val * multiplier'
 *                              (positive = left shift, negative = right shift)
 *
 * @return          Returns (val * multiplier)/(2 ^ shift)
 *
 * @details A positive shift is applied as a left shift on 'val' before the
 *          saturating doubling-high multiply; a negative shift becomes a
 *          rounding right shift of the product (gemmlowp-style requantization).
 */
__STATIC_FORCEINLINE q31_t arm_nn_requantize(const q31_t val, const q31_t multiplier, const q31_t shift)
{
    return arm_nn_divide_by_power_of_two(arm_nn_sat_doubling_high_mult(val * (1 << LEFT_SHIFT(shift)), multiplier),
                                         RIGHT_SHIFT(shift));
}
/**
 * @brief memcpy optimized for MVE
 * @param[in, out] dst         Destination pointer
 * @param[in]      src         Source pointer.
 * @param[in]      block_size  Number of bytes to copy.
 *
 * @details On MVE targets a tail-predicated loop (WLSTP/LETP) copies 16
 *          bytes per iteration with predication covering the remainder;
 *          clobbers q0 and lr (r14). Falls back to the C library memcpy
 *          elsewhere. Regions must not overlap.
 */
__STATIC_FORCEINLINE void arm_memcpy_q7(q7_t *__RESTRICT dst,
                                        const q7_t *__RESTRICT src,
                                        uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
    __asm volatile (
        "   wlstp.8                 lr, %[cnt], 1f             \n"
        "2:                                                    \n"
        "   vldrb.8                 q0, [%[in]], 16            \n"
        "   vstrb.8                 q0, [%[out]], 16           \n"
        "   letp                    lr, 2b                     \n"
        "1:                                                    \n"
        :[in] "+r"(src)
        ,[out] "+r"(dst)
        :[cnt] "r"(block_size)
        :"q0", "memory", "r14");
#else
    memcpy(dst, src, block_size);
#endif
}
#if defined(ARM_MATH_MVEI)
/**
 * @brief           Vector saturating doubling high multiply returning high half.
 * @param[in]       m1        Multiplicand vector
 * @param[in]       m2        Multiplier (scalar, broadcast to all lanes)
 * @return          Result of multiplication.
 *
 * @details Thin wrapper over the VQRDMULH intrinsic; per-lane behavior
 *          matches arm_nn_sat_doubling_high_mult().
 */
__STATIC_FORCEINLINE int32x4_t arm_sat_doubling_high_mult_mve(const int32x4_t m1, const q31_t m2)
{
    return vqrdmulhq_n_s32(m1, m2);
}
/**
 * @brief           Vector rounding divide by power of two.
 * @param[in]       dividend - Dividend vector
 * @param[in]       exponent - Divisor = power(2, exponent)
 *                             Range: [0, 31]
 * @return          Rounded result of division. Midpoint is rounded away from zero.
 *
 * @details gemmlowp-style fixup: negative dividends are pre-adjusted via the
 *          (dividend & shift) >> 31 term so that vrshlq's round-half-up
 *          becomes round-half-away-from-zero, matching the scalar
 *          arm_nn_divide_by_power_of_two().
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve(const int32x4_t dividend, const q31_t exponent)
{
    const int32x4_t shift = vdupq_n_s32(-exponent);
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    return vrshlq_s32(fixed_up_dividend, shift);
}
  731. /**
  732. * @brief Requantize a given vector.
  733. * @param[in] val Vector to be requantized
  734. * @param[in] multiplier multiplier
  735. * @param[in] shift shift
  736. *
  737. * @return Returns (val * multiplier)/(2 ^ shift)
  738. *
  739. */
  740. __STATIC_FORCEINLINE int32x4_t arm_requantize_mve(const int32x4_t val, const q31_t multiplier, const q31_t shift)
  741. {
  742. return arm_divide_by_power_of_two_mve(
  743. arm_sat_doubling_high_mult_mve(vshlq_s32(val, vdupq_n_s32(LEFT_SHIFT(shift))), multiplier),
  744. RIGHT_SHIFT(shift));
  745. }
  746. __STATIC_FORCEINLINE int32x4_t arm_sat_doubling_high_mult_mve_32x4(const int32x4_t m1, const int32x4_t m2)
  747. {
  748. return vqrdmulhq_s32(m1, m2);
  749. }
  750. __STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve_32x4(const int32x4_t dividend, const int32x4_t exponent)
  751. {
  752. const int32x4_t shift = -exponent;
  753. const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
  754. const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
  755. return vrshlq_s32(fixed_up_dividend, shift);
  756. }
  757. __STATIC_FORCEINLINE int32x4_t arm_requantize_mve_32x4(const int32x4_t val, const int32x4_t multiplier, const int32x4_t shift)
  758. {
  759. const int32x4_t zz = vdupq_n_s32(0);
  760. const mve_pred16_t p = vcmpgtq_n_s32(shift, 0);
  761. const int32x4_t left_shift = vpselq_s32(shift, zz, p);
  762. const int32x4_t right_shift = -vpselq_s32(zz, shift, p);
  763. return arm_divide_by_power_of_two_mve_32x4(arm_sat_doubling_high_mult_mve_32x4(vshlq_s32(val, left_shift), multiplier), right_shift);
  764. }
  765. #endif
  766. // @note The following functions are used only for softmax layer, scaled bits = 5 assumed
__STATIC_FORCEINLINE int32_t arm_nn_exp_on_negative_values(int32_t val)
{
/* Fixed-point exp(val) for val <= 0 (Q-format input with 5 scaled integer bits,
 * per the note above). Strategy: reduce val modulo -1/4, evaluate a polynomial
 * approximation of exp on that small remainder, then reconstruct the full result
 * by conditionally multiplying in precomputed exp(-2^k) constants selected by the
 * bits of the remainder. */
int32_t mask = 0;
int32_t shift = 24;
/* Reduce to [-2^24, 0): val_mod_minus_quarter is val mod (1 << 24), minus (1 << 24). */
const int32_t val_mod_minus_quarter = (val & ((1 << shift) - 1)) - (1 << shift);
/* remainder holds the part of val removed by the reduction; its bits drive the
 * per-power-of-two correction factors below. */
const int32_t remainder = val_mod_minus_quarter - val;
const int32_t x = (val_mod_minus_quarter << 5) + (1 << 28);
const int32_t x2 = MUL_SAT(x, x);
/* Polynomial (Pade-style) approximation of exp around 0; the large constants are
 * Q-format encodings of the series coefficients. */
int32_t result = 1895147668 + MUL_SAT(1895147668, x +
DIV_POW2(MUL_SAT(DIV_POW2(MUL_SAT(x2, x2), 2) + MUL_SAT(x2, x), 715827883) + x2, 1));
/* Multiply 'result' by constant x when the corresponding remainder bit is set.
 * NOTE: 'shift++' inside the macro advances the tested bit on each invocation,
 * so the invocation order below matters. */
#define SELECT_IF_NON_ZERO(x) \
{ \
mask = MASK_IF_NON_ZERO(remainder & (1 << shift++)); \
result = SELECT_USING_MASK(mask, MUL_SAT(result, x), result); \
}
/* Each constant is the fixed-point encoding of exp(-2^k) for successive k. */
SELECT_IF_NON_ZERO(1672461947)
SELECT_IF_NON_ZERO(1302514674)
SELECT_IF_NON_ZERO(790015084)
SELECT_IF_NON_ZERO(290630308)
SELECT_IF_NON_ZERO(39332535)
SELECT_IF_NON_ZERO(720401)
SELECT_IF_NON_ZERO(242)
#undef SELECT_IF_NON_ZERO
/* exp(0) == 1, which saturates to Q31_MAX in this representation. */
mask = MASK_IF_ZERO(val);
return SELECT_USING_MASK(mask, Q31_MAX, result);
}
  793. __STATIC_FORCEINLINE q31_t arm_nn_mult_by_power_of_two(const int32_t val, const int32_t exp)
  794. {
  795. const int32_t thresh = ((1 << (31 - exp)) - 1);
  796. int32_t result = val << exp;
  797. result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val > thresh), Q31_MAX, result);
  798. result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val < -thresh), Q31_MIN, result);
  799. return result;
  800. }
__STATIC_FORCEINLINE int32_t arm_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)
{
/* Fixed-point 1/(1+x) for x in [0, 1): compute half the denominator (1+x)/2 in
 * 64-bit to avoid overflow, seed a Q-format estimate of its reciprocal, refine it
 * with three Newton-Raphson iterations, and return the doubled result. */
const int64_t sum = (int64_t)val + (int64_t)Q31_MAX;
/* Round-half-away-from-zero divide by 2. */
const int32_t half_denominator = (int32_t)((sum + (sum >= 0 ? 1 : -1)) / 2L);
/* Initial linear estimate of 1/half_denominator; the constants are Q-format
 * encodings of the seed line's intercept and slope. */
int32_t x = 1515870810 + MUL_SAT(half_denominator, -1010580540);
/* 'shift' is the fixed-point representation of 1.0 used by the iteration
 * (one in Q29), not a shift amount. */
const int32_t shift = (1 << 29);
/* Newton-Raphson: x += x * (1 - half_denominator * x), scaled by 4 to undo the
 * Q-format shrinkage of the nested saturating multiplies. */
x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
/* Double to convert the reciprocal of the half-denominator into 1/(1+x). */
return MUL_POW2(x, 1);
}
  812. #ifdef __cplusplus
  813. }
  814. #endif
  815. #endif