arm_nnsupportfunctions.h 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973
  1. /*
  2. * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_nnsupportfunctions.h
  21. * Description: Public header file of support functions for CMSIS NN Library
  22. *
  23. * $Date: 15. April 2021
  24. * $Revision: V.5.5.0
  25. *
  26. * Target Processor: Cortex-M CPUs
  27. * -------------------------------------------------------------------- */
  28. #ifndef _ARM_NNSUPPORTFUNCTIONS_H_
  29. #define _ARM_NNSUPPORTFUNCTIONS_H_
  30. #include "arm_common_tables.h"
  31. #include "arm_math_types.h"
  32. #ifdef __cplusplus
  33. extern "C" {
  34. #endif
  35. #define LEFT_SHIFT(_shift) (_shift > 0 ? _shift : 0)
  36. #define RIGHT_SHIFT(_shift) (_shift > 0 ? 0 : -_shift)
  37. #define MASK_IF_ZERO(x) (x) == 0 ? ~0 : 0
  38. #define MASK_IF_NON_ZERO(x) (x) != 0 ? ~0 : 0
  39. #define SELECT_USING_MASK(mask, a, b) ((mask) & (a)) ^ (~(mask) & (b))
  40. #define MAX(A, B) ((A) > (B) ? (A) : (B))
  41. #define MIN(A, B) ((A) < (B) ? (A) : (B))
  42. #define CLAMP(x, h, l) MAX(MIN((x), (h)), (l))
/**
 * @brief Union for SIMD access of q31/q15/q7 types
 *
 * Allows a single 32-bit word to be reinterpreted as two q15 half-words
 * or four q7 bytes without pointer casts.
 */
union arm_nnword
{
    q31_t word;          /**< q31 type */
    q15_t half_words[2]; /**< q15 type */
    q7_t bytes[4];       /**< q7 type */
};
/**
 * @brief Union for data type long long
 *
 * Gives 32-bit access to the two halves of a 64-bit accumulator.
 * NOTE(review): the low-then-high member order matches a little-endian
 * layout; arm_nn_doubling_high_mult_no_sat() below writes word.low/high
 * and then reads long_long, so it relies on this mapping -- confirm if
 * big-endian targets are in scope.
 */
struct arm_nn_double
{
    uint32_t low;  /* least-significant 32 bits (little-endian layout) */
    int32_t high;  /* most-significant 32 bits, signed */
};

union arm_nn_long_long
{
    int64_t long_long;
    struct arm_nn_double word;
};
  68. /**
  69. * @defgroup nndata_convert Neural Network Data Conversion Functions
  70. *
  71. * Perform data type conversion in-between neural network operations
  72. *
  73. */
  74. /**
  75. * @brief Converts the elements of the q7 vector to q15 vector without left-shift
  76. * @param[in] *pSrc points to the q7 input vector
  77. * @param[out] *pDst points to the q15 output vector
  78. * @param[in] blockSize length of the input vector
  79. *
  80. */
  81. void arm_q7_to_q15_no_shift(const q7_t *pSrc, q15_t *pDst, uint32_t blockSize);
  82. /**
  83. * @brief Non-saturating addition of elements of a q7 vector
  84. * @param[in] *input Pointer to the q7 input vector
  85. * @param[out] *output Pointer to the q31 output variable.
  86. * @param[in] block_size length of the input vector
  87. * \par Description:
  88. *
  89. * 2^24 samples can be added without saturating the result.
  90. *
  91. * The equation used for the conversion process is:
  92. *
  93. * <pre>
  94. * sum = input[0] + input[1] + .. + input[block_size -1]
  95. * </pre>
  96. *
  97. * */
  98. void arm_nn_add_q7(const q7_t *input, q31_t *output, uint32_t block_size);
  99. /**
  100. * @brief Converts the elements of the q7 vector to reordered q15 vector without left-shift
  101. * @param[in] *pSrc points to the q7 input vector
  102. * @param[out] *pDst points to the q15 output vector
  103. * @param[in] blockSize length of the input vector
  104. * @return none.
  105. *
  106. */
  107. void arm_q7_to_q15_reordered_no_shift(const q7_t *pSrc, q15_t *pDst, uint32_t blockSize);
  108. /**
  109. * @brief Converts the elements from a q7 vector to a q15 vector with an added offset
  110. * @param[in] src pointer to the q7 input vector
  111. * @param[out] dst pointer to the q15 output vector
  112. * @param[in] block_size length of the input vector
  113. * @param[in] offset q7 offset to be added to each input vector element.
  114. *
  115. * \par Description:
  116. *
  117. * The equation used for the conversion process is:
  118. *
  119. * <pre>
  120. * dst[n] = (q15_t) src[n] + offset; 0 <= n < block_size.
  121. * </pre>
  122. *
  123. */
  124. void arm_q7_to_q15_with_offset(const q7_t *src, q15_t *dst, uint32_t block_size, q15_t offset);
  125. /**
  126. * @brief Converts the elements of the q7 vector to reordered q15 vector with an added offset
  127. * @param[in] src pointer to the q7 input vector
  128. * @param[out] dst pointer to the q15 output vector
  129. * @param[in] block_size length of the input vector
  130. * @param[in] offset offset to be added to each input vector element.
  131. * @return none.
  132. *
  133. * @details This function does the q7 to q15 expansion with re-ordering of bytes. Re-ordering is a consequence of
  134. * the sign extension intrinsic(DSP extension). The tail (i.e., last (N % 4) elements) retains its
  135. * original order.
  136. *
  137. */
  138. void arm_q7_to_q15_reordered_with_offset(const q7_t *src, q15_t *dst, uint32_t block_size, q15_t offset);
  139. /**
  140. * @brief Converts the elements from a q7 vector and accumulate to a q15 vector
  141. * @param[in] *src points to the q7 input vector
  142. * @param[out] *dst points to the q15 output vector
  143. * @param[in] block_size length of the input vector
  144. *
  145. * \par Description:
  146. *
  147. * The equation used for the conversion process is:
  148. *
  149. * <pre>
  150. * dst[n] += (q15_t) src[n] ; 0 <= n < block_size.
  151. * </pre>
  152. *
  153. */
  154. void arm_nn_accumulate_q7_to_q15(q15_t *dst, const q7_t *src, uint32_t block_size);
  155. /**
  156. * @brief Depthwise conv on an im2col buffer where the input channel equals output channel.
  157. * @param[in] row pointer to row
  158. * @param[in] col pointer to im2col buffer, always consists of 2 columns.
  159. * @param[in] num_ch number of channels
  160. * @param[in] out_shift pointer to per output channel requantization shift parameter.
  161. * @param[in] out_mult pointer to per output channel requantization multiplier parameter.
  162. * @param[in] out_offset output tensor offset.
  163. * @param[in] activation_min minimum value to clamp the output to. Range : int8
  164. * @param[in] activation_max maximum value to clamp the output to. Range : int8
  165. * @param[in] kernel_size number of elements in one column.
  166. * @param[in] output_bias per output channel bias. Range : int32
  167. * @param[out] out pointer to output
  168. * @return The function returns one of the two
  169. * 1. The incremented output pointer for a successful operation or
  170. * 2. NULL if implementation is not available.
  171. *
  172. * @details Supported framework: TensorFlow Lite micro.
  173. */
  174. q7_t *arm_nn_depthwise_conv_s8_core(const q7_t *row,
  175. const q15_t *col,
  176. const uint16_t num_ch,
  177. const int32_t *out_shift,
  178. const int32_t *out_mult,
  179. const int32_t out_offset,
  180. const int32_t activation_min,
  181. const int32_t activation_max,
  182. const uint16_t kernel_size,
  183. const int32_t *const output_bias,
  184. q7_t *out);
  185. /**
  186. * @brief General Matrix-multiplication function with per-channel requantization.
  187. * @param[in] input_row pointer to row operand
  188. * @param[in] input_col pointer to col operand
  189. * @param[in] output_ch number of rows of input_row
  190. * @param[in] col_batches number of column batches. Range: 1 to 4
  191. * @param[in] output_shift pointer to per output channel requantization shift parameter.
  192. * @param[in] output_mult pointer to per output channel requantization multiplier parameter.
  193. * @param[in] out_offset output tensor offset.
  194. * @param[in] col_offset input tensor(col) offset.
  195. * @param[in] row_offset kernel offset(row). Not used.
  196. * @param[in] out_activation_min minimum value to clamp the output to. Range : int8
  197. * @param[in] out_activation_max maximum value to clamp the output to. Range : int8
  198. * @param[in] row_len number of elements in each row
  199. * @param[in] bias per output channel bias. Range : int32
  200. * @param[in,out] out pointer to output
  201. * @return The function returns one of the two
  202. * 1. The incremented output pointer for a successful operation or
  203. * 2. NULL if implementation is not available.
  204. *
  205. * @details Supported framework: TensorFlow Lite
  206. */
  207. q7_t *arm_nn_mat_mult_s8(const q7_t *input_row,
  208. const q7_t *input_col,
  209. const uint16_t output_ch,
  210. const uint16_t col_batches,
  211. const int32_t *output_shift,
  212. const int32_t *output_mult,
  213. const int32_t out_offset,
  214. const int32_t col_offset,
  215. const int32_t row_offset,
  216. const int16_t out_activation_min,
  217. const int16_t out_activation_max,
  218. const uint16_t row_len,
  219. const int32_t *const bias,
  220. q7_t *out);
  221. /**
  222. * @brief General Matrix-multiplication without requantization for one row & one column
  223. * @param[in] row_elements number of row elements
  224. * @param[in] row_base pointer to row operand
  225. * @param[in] col_base pointer to col operand
  226. * @param[out] sum_col pointer to store sum of column elements
  227. * @param[out] output pointer to store result of multiply-accumulate
  228. * @return The function returns the multiply-accumulated result of the row by column.
  229. *
  230. * @details Pseudo-code
  231. * *output = 0
  232. * sum_col = 0
  233. * for (i = 0; i < row_elements; i++)
  234. * *output += row_base[i] * col_base[i]
  235. * sum_col += col_base[i]
  236. *
  237. */
  238. arm_status arm_nn_mat_mul_core_1x_s8(int32_t row_elements,
  239. const int8_t *row_base,
  240. const int8_t *col_base,
  241. int32_t *const sum_col,
  242. int32_t *const output);
  243. /**
  244. * @brief General Matrix-multiplication without requantization for four rows and one column
  245. * @param[in] row_elements number of row elements
  246. * @param[in] offset offset between rows. Can be the same as row_elements.
  247. * For e.g, in a 1x1 conv scenario with stride as 1.
  248. * @param[in] row_base pointer to row operand
  249. * @param[in] col_base pointer to col operand
  250. * @param[out] sum_col pointer to store sum of column elements
  251. * @param[out] output pointer to store result(4 int32's) of multiply-accumulate
  252. * @return The function returns the multiply-accumulated result of the row by column
  253. *
  254. * @details Pseudo-code
  255. * output[0] = 0
  256. * ..
  257. * output[3] = 0
  258. * sum_col = 0
  259. * for (i = 0; i < row_elements; i++)
  260. * output[0] += row_base[i] * col_base[i]
  261. * ..
  262. * output[3] += row_base[i + (row_elements * 3)] * col_base[i]
  263. * sum_col += col_base[i]
  264. */
  265. arm_status arm_nn_mat_mul_core_4x_s8(const int32_t row_elements,
  266. const int32_t offset,
  267. const int8_t *row_base,
  268. const int8_t *col_base,
  269. int32_t *const sum_col,
  270. int32_t *const output);
  271. /**
  272. * @brief General Matrix-multiplication function with per-channel requantization.
  273. * This function assumes:
  274. * - LHS input matrix NOT transposed (nt)
  275. * - RHS input matrix transposed (t)
  276. *
  277. * @note This operation also performs the broadcast bias addition before the requantization
  278. *
  279. * @param[in] lhs Pointer to the LHS input matrix
  280. * @param[in] rhs Pointer to the RHS input matrix
  281. * @param[in] bias Pointer to the bias vector. The length of this vector is equal to the number of
  282. * output columns (or RHS input rows)
  283. * @param[out] dst Pointer to the output matrix with "m" rows and "n" columns
  284. * @param[in] dst_multipliers Pointer to the multipliers vector needed for the per-channel requantization.
  285. * The length of this vector is equal to the number of output columns (or RHS input
  286. * rows)
  287. * @param[in] dst_shifts Pointer to the shifts vector needed for the per-channel requantization. The length
  288. * of this vector is equal to the number of output columns (or RHS input rows)
  289. * @param[in] lhs_rows Number of LHS input rows
  290. * @param[in] rhs_rows Number of RHS input rows
  291. * @param[in] rhs_cols Number of LHS/RHS input columns
  292. * @param[in] lhs_offset Offset to be applied to the LHS input value
  293. * @param[in] dst_offset Offset to be applied the output result
  294. * @param[in] activation_min Minimum value to clamp down the output. Range : int8
  295. * @param[in] activation_max Maximum value to clamp up the output. Range : int8
  296. *
  297. * @return The function returns <code>ARM_MATH_SUCCESS</code>
  298. *
  299. */
  300. arm_status arm_nn_mat_mult_nt_t_s8(const q7_t *lhs,
  301. const q7_t *rhs,
  302. const q31_t *bias,
  303. q7_t *dst,
  304. const int32_t *dst_multipliers,
  305. const int32_t *dst_shifts,
  306. const int32_t lhs_rows,
  307. const int32_t rhs_rows,
  308. const int32_t rhs_cols,
  309. const int32_t lhs_offset,
  310. const int32_t dst_offset,
  311. const int32_t activation_min,
  312. const int32_t activation_max);
  313. /**
  314. * @brief s8 Vector by Matrix (transposed) multiplication
  315. *
  316. * @param[in] lhs Input left-hand side vector
  317. * @param[in] rhs Input right-hand side matrix (transposed)
  318. * @param[in] bias Input bias
  319. * @param[out] dst Output vector
  320. * @param[in] lhs_offset Offset to be added to the input values of the left-hand side vector.
  321. * Range: -127 to 128
  322. * @param[in] rhs_offset Not used
  323. * @param[in] dst_offset Offset to be added to the output values. Range: -127 to 128
  324. * @param[in] dst_multiplier Output multiplier
  325. * @param[in] dst_shift Output shift
  326. * @param[in] rhs_cols Number of columns in the right-hand side input matrix
  327. * @param[in] rhs_rows Number of rows in the right-hand side input matrix
  328. * @param[in] activation_min Minimum value to clamp the output to. Range: int8
  329. * @param[in] activation_max Maximum value to clamp the output to. Range: int8
  330. *
  331. * @return The function returns <code>ARM_MATH_SUCCESS</code>
  332. *
  333. */
  334. arm_status arm_nn_vec_mat_mult_t_s8(const q7_t *lhs,
  335. const q7_t *rhs,
  336. const q31_t *bias,
  337. q7_t *dst,
  338. const int32_t lhs_offset,
  339. const int32_t rhs_offset,
  340. const int32_t dst_offset,
  341. const int32_t dst_multiplier,
  342. const int32_t dst_shift,
  343. const int32_t rhs_cols,
  344. const int32_t rhs_rows,
  345. const int32_t activation_min,
  346. const int32_t activation_max);
  347. /**
  348. * @brief s8 Vector by Matrix (transposed) multiplication with s16 output
  349. *
  350. * @param[in] lhs Input left-hand side vector
  351. * @param[in] rhs Input right-hand side matrix (transposed)
  352. * @param[out] dst Output vector
  353. * @param[in] lhs_offset Offset to be added to the input values of the left-hand side
  354. * vector. Range: -127 to 128
  355. * @param[in] rhs_offset Not used
  356. * @param[in] scatter_offset Address offset for dst. First output is stored at 'dst', the
  357. * second at 'dst + scatter_offset' and so on.
  358. * @param[in] dst_multiplier Output multiplier
  359. * @param[in] dst_shift Output shift
  360. * @param[in] rhs_cols Number of columns in the right-hand side input matrix
  361. * @param[in] rhs_rows Number of rows in the right-hand side input matrix
  362. * @param[in] activation_min Minimum value to clamp the output to. Range: int16
  363. * @param[in] activation_max Maximum value to clamp the output to. Range: int16
  364. *
  365. * @return The function returns <code>ARM_MATH_SUCCESS</code>
  366. *
  367. */
  368. arm_status arm_nn_vec_mat_mult_t_svdf_s8(const q7_t *lhs,
  369. const q7_t *rhs,
  370. q15_t *dst,
  371. const int32_t lhs_offset,
  372. const int32_t rhs_offset,
  373. const int32_t scatter_offset,
  374. const int32_t dst_multiplier,
  375. const int32_t dst_shift,
  376. const int32_t rhs_cols,
  377. const int32_t rhs_rows,
  378. const int32_t activation_min,
  379. const int32_t activation_max);
  380. /**
  381. * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in padded cases where
  382. * the padding is -lhs_offset(Range: int8). Dimensions are the same for lhs and rhs.
  383. *
  384. * @param[in] lhs Input left-hand side matrix
  385. * @param[in] rhs Input right-hand side matrix (transposed)
  386. * @param[in] lhs_offset LHS matrix offset(input offset). Range: -127 to 128
  387. * @param[in] num_ch Number of channels in LHS/RHS
  388. * @param[in] out_shift Per channel output shift. Length of vector is equal to number of channels
  389. * @param[in] out_mult Per channel output multiplier. Length of vector is equal to number of channels
  390. * @param[in] out_offset Offset to be added to the output values. Range: -127 to 128
  391. * @param[in] activation_min Minimum value to clamp the output to. Range: int8
  392. * @param[in] activation_max Maximum value to clamp the output to. Range: int8
  393. * @param[in] row_x_col (row_dimension * col_dimension) of LHS/RHS matrix
  394. * @param[in] output_bias Per channel output bias. Length of vector is equal to number of channels
  395. * @param[in] out Output pointer
  396. *
  397. * @return The function returns one of the two
  398. * - Updated output pointer if an implementation is available
  399. * - NULL if no implementation is available.
  400. *
  401. * @note If number of channels is not a multiple of 4, upto 3 elements outside the boundary will be read
  402. * out for the following.
  403. * - Output shift
  404. * - Output multiplier
  405. * - Output bias
  406. * - rhs
  407. */
  408. q7_t *arm_nn_depthwise_conv_nt_t_padded_s8(const q7_t *lhs,
  409. const q7_t *rhs,
  410. const int32_t lhs_offset,
  411. const uint16_t num_ch,
  412. const int32_t *out_shift,
  413. const int32_t *out_mult,
  414. const int32_t out_offset,
  415. const int32_t activation_min,
  416. const int32_t activation_max,
  417. const uint16_t row_x_col,
  418. const int32_t *const output_bias,
  419. q7_t *out);
  420. /**
  421. * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases.
  422. * Dimensions are the same for lhs and rhs.
  423. *
  424. * @param[in] lhs Input left-hand side matrix
  425. * @param[in] rhs Input right-hand side matrix (transposed)
  426. * @param[in] lhs_offset LHS matrix offset(input offset). Range: -127 to 128
  427. * @param[in] num_ch Number of channels in LHS/RHS
  428. * @param[in] out_shift Per channel output shift. Length of vector is equal to number of channels.
  429. * @param[in] out_mult Per channel output multiplier. Length of vector is equal to number of channels.
  430. * @param[in] out_offset Offset to be added to the output values. Range: -127 to 128
  431. * @param[in] activation_min Minimum value to clamp the output to. Range: int8
  432. * @param[in] activation_max Maximum value to clamp the output to. Range: int8
  433. * @param[in] row_x_col (row_dimension * col_dimension) of LHS/RHS matrix
  434. * @param[in] output_bias Per channel output bias. Length of vector is equal to number of channels.
  435. * @param[in] out Output pointer
  436. *
  437. * @return The function returns one of the two
  438. * - Updated output pointer if an implementation is available
  439. * - NULL if no implementation is available.
  440. *
  441. * @note If number of channels is not a multiple of 4, upto 3 elements outside the boundary will be read
  442. * out for the following.
  443. * - Output shift
  444. * - Output multiplier
  445. * - Output bias
  446. * - rhs
  447. */
  448. q7_t *arm_nn_depthwise_conv_nt_t_s8(const q7_t *lhs,
  449. const q7_t *rhs,
  450. const int32_t lhs_offset,
  451. const uint16_t num_ch,
  452. const int32_t *out_shift,
  453. const int32_t *out_mult,
  454. const int32_t out_offset,
  455. const int32_t activation_min,
  456. const int32_t activation_max,
  457. const uint16_t row_x_col,
  458. const int32_t *const output_bias,
  459. q7_t *out);
  460. /**
  461. @brief Read 2 q15 elements and post increment pointer.
  462. @param[in] in_q15 Pointer to pointer that holds address of input.
  463. @return q31 value
  464. */
  465. __STATIC_FORCEINLINE q31_t arm_nn_read_q15x2_ia(const q15_t **in_q15)
  466. {
  467. q31_t val;
  468. memcpy(&val, *in_q15, 4);
  469. *in_q15 += 2;
  470. return (val);
  471. }
  472. /**
  473. @brief Read 4 q7 from q7 pointer and post increment pointer.
  474. @param[in] in_q7 Pointer to pointer that holds address of input.
  475. @return q31 value
  476. */
  477. __STATIC_FORCEINLINE q31_t arm_nn_read_q7x4_ia(const q7_t **in_q7)
  478. {
  479. q31_t val;
  480. memcpy(&val, *in_q7, 4);
  481. *in_q7 += 4;
  482. return (val);
  483. }
  484. /**
  485. @brief Read 2 q15 from q15 pointer.
  486. @param[in] in_q15 pointer to address of input.
  487. @return q31 value
  488. */
  489. __STATIC_FORCEINLINE q31_t arm_nn_read_q15x2(const q15_t *in_q15)
  490. {
  491. q31_t val;
  492. memcpy(&val, in_q15, 4);
  493. return (val);
  494. }
  495. /**
  496. @brief Read 4 q7 values.
  497. @param[in] in_q7 pointer to address of input.
  498. @return q31 value
  499. */
  500. __STATIC_FORCEINLINE q31_t arm_nn_read_q7x4(const q7_t *in_q7)
  501. {
  502. q31_t val;
  503. memcpy(&val, in_q7, 4);
  504. return (val);
  505. }
/**
 * @brief memset optimized for MVE
 * @param[in, out] dst   Destination pointer
 * @param[in] val        Value to set
 * @param[in] block_size Number of bytes to set
 *
 */
__STATIC_FORCEINLINE void arm_memset_q7(q7_t *dst, const q7_t val, uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
    /* Vectorized byte fill: duplicate 'val' across q0, then run a
       tail-predicated low-overhead loop (wlstp/letp) storing 16 bytes per
       iteration; the predication handles any remainder < 16 bytes. */
    __asm volatile(" vdup.8 q0, %[set_val] \n"
                   " wlstp.8 lr, %[cnt], 1f \n"
                   "2: \n"
                   " vstrb.8 q0, [%[in]], 16 \n"
                   " letp lr, 2b \n"
                   "1: \n"
                   : [ in ] "+r"(dst)
                   : [ cnt ] "r"(block_size), [ set_val ] "r"(val)
                   : "q0", "memory", "r14");
#else
    /* Plain C fallback when MVE is not available. */
    memset(dst, val, block_size);
#endif
}
  529. #if defined(ARM_MATH_DSP)
/**
 * @brief read and expand one q7 word into two q15 words
 *
 * Reads 4 consecutive q7 values and sign-extends them into two q31 words,
 * each holding a pair of q15 values. The PKH* re-pack undoes the even/odd
 * interleaving introduced by SXTB16, so (unlike the _reordered variants
 * below) element order is preserved.
 *
 * @return the source pointer advanced by 4 q7 elements
 */
__STATIC_FORCEINLINE const q7_t *read_and_pad(const q7_t *source, q31_t *out1, q31_t *out2)
{
    q31_t inA = arm_nn_read_q7x4_ia(&source);
    q31_t inAbuf1 = __SXTB16(__ROR((uint32_t)inA, 8)); /* sign-extends bytes 1 and 3 */
    q31_t inAbuf2 = __SXTB16(inA);                     /* sign-extends bytes 0 and 2 */

#ifndef ARM_MATH_BIG_ENDIAN
    *out2 = (int32_t)(__PKHTB(inAbuf1, inAbuf2, 16));
    *out1 = (int32_t)(__PKHBT(inAbuf2, inAbuf1, 16));
#else
    /* Byte positions are mirrored on big-endian, so the outputs swap. */
    *out1 = (int32_t)(__PKHTB(inAbuf1, inAbuf2, 16));
    *out2 = (int32_t)(__PKHBT(inAbuf2, inAbuf1, 16));
#endif

    return source;
}
/**
 * @brief read and expand one q7 word into two q15 words with reordering
 *
 * Like read_and_pad() but without the re-packing step, so the q15 outputs
 * remain in SXTB16-interleaved order (even-indexed bytes in one word,
 * odd-indexed bytes in the other). Consumers must expect this order.
 *
 * @return the source pointer advanced by 4 q7 elements
 */
__STATIC_FORCEINLINE const q7_t *read_and_pad_reordered(const q7_t *source, q31_t *out1, q31_t *out2)
{
    q31_t inA = arm_nn_read_q7x4_ia(&source);
#ifndef ARM_MATH_BIG_ENDIAN
    *out2 = __SXTB16(__ROR((uint32_t)inA, 8)); /* sign-extends bytes 1 and 3 */
    *out1 = __SXTB16(inA);                     /* sign-extends bytes 0 and 2 */
#else
    /* Byte positions are mirrored on big-endian, so the outputs swap. */
    *out1 = __SXTB16(__ROR((uint32_t)inA, 8));
    *out2 = __SXTB16(inA);
#endif
    return source;
}
/**
 * @brief read and expand one q7 word into two q15 words with reordering and add an offset
 *
 * Same interleaved expansion as read_and_pad_reordered(), followed by a
 * saturating per-half-word add of 'offset' via __QADD16.
 * NOTE(review): for the offset to be applied to both q15 elements, callers
 * appear responsible for passing 'offset' replicated in both 16-bit halves
 * of the q31 argument -- confirm against call sites.
 *
 * @return the source pointer advanced by 4 q7 elements
 */
__STATIC_FORCEINLINE const q7_t *
read_and_pad_reordered_with_offset(const q7_t *source, q31_t *out1, q31_t *out2, q31_t offset)
{
    q31_t inA = arm_nn_read_q7x4_ia(&source);

#ifndef ARM_MATH_BIG_ENDIAN
    *out2 = __SXTB16(__ROR((uint32_t)inA, 8)); /* sign-extends bytes 1 and 3 */
    *out1 = __SXTB16(inA);                     /* sign-extends bytes 0 and 2 */
#else
    /* Byte positions are mirrored on big-endian, so the outputs swap. */
    *out1 = __SXTB16(__ROR((uint32_t)inA, 8));
    *out2 = __SXTB16(inA);
#endif
    *out1 = __QADD16(*out1, offset);
    *out2 = __QADD16(*out2, offset);

    return source;
}
  580. #endif
  581. /**
  582. * @defgroup NNBasicMath Basic Math Functions for Neural Network Computation
  583. *
  584. * Basic Math Functions for Neural Network Computation
  585. *
  586. */
  587. /**
  588. * @brief q15 vector multiplication with variable output shifts
  589. * @param[in] *pSrcA pointer to the first input vector
  590. * @param[in] *pSrcB pointer to the second input vector
  591. * @param[out] *pDst pointer to the output vector
  592. * @param[in] out_shift amount of right-shift for output
  593. * @param[in] blockSize number of samples in each vector
  594. * @return none.
  595. *
  596. * <b>Scaling and Overflow Behavior:</b>
  597. * \par
  598. * The function uses saturating arithmetic.
  599. * Results outside of the allowable q15 range [0x8000 0x7FFF] will be saturated.
  600. */
  601. void arm_nn_mult_q15(q15_t *pSrcA, q15_t *pSrcB, q15_t *pDst, const uint16_t out_shift, uint32_t blockSize);
  602. /**
  603. * @brief q7 vector multiplication with variable output shifts
  604. * @param[in] *pSrcA pointer to the first input vector
  605. * @param[in] *pSrcB pointer to the second input vector
  606. * @param[out] *pDst pointer to the output vector
  607. * @param[in] out_shift amount of right-shift for output
  608. * @param[in] blockSize number of samples in each vector
  609. * @return none.
  610. *
  611. * <b>Scaling and Overflow Behavior:</b>
  612. * \par
  613. * The function uses saturating arithmetic.
  614. * Results outside of the allowable q7 range [0x80 0x7F] will be saturated.
  615. */
  616. void arm_nn_mult_q7(q7_t *pSrcA, q7_t *pSrcB, q7_t *pDst, const uint16_t out_shift, uint32_t blockSize);
  617. /**
  618. * @brief macro for adding rounding offset
  619. */
  620. #ifndef ARM_NN_TRUNCATE
  621. #define NN_ROUND(out_shift) ((0x1u << out_shift) >> 1)
  622. #else
  623. #define NN_ROUND(out_shift) 0
  624. #endif
  625. // Macros for shortening quantization functions' names and avoid long lines
  626. #define MUL_SAT(a, b) arm_nn_doubling_high_mult((a), (b))
  627. #define MUL_SAT_MVE(a, b) arm_doubling_high_mult_mve_32x4((a), (b))
  628. #define MUL_POW2(a, b) arm_nn_mult_by_power_of_two((a), (b))
  629. #define DIV_POW2(a, b) arm_nn_divide_by_power_of_two((a), (b))
  630. #define DIV_POW2_MVE(a, b) arm_divide_by_power_of_two_mve((a), (b))
  631. #define EXP_ON_NEG(x) arm_nn_exp_on_negative_values((x))
  632. #define ONE_OVER1(x) arm_nn_one_over_one_plus_x_for_x_in_0_1((x))
/**
 * @brief Saturating doubling high multiply. Result matches
 *        NEON instruction VQRDMULH.
 * @param[in] m1 Multiplicand. Range: {Q31_MIN, Q31_MAX}
 * @param[in] m2 Multiplier. Range: {Q31_MIN, Q31_MAX}
 * @return Result of multiplication.
 *
 * Computes round((2 * m1 * m2) / 2^32), saturating the single overflow
 * case Q31_MIN * Q31_MIN to Q31_MAX.
 */
__STATIC_FORCEINLINE q31_t arm_nn_doubling_high_mult(const q31_t m1, const q31_t m2)
{
    q31_t result = 0;
    // Rounding offset to add for a right shift of 31
    q63_t mult = 1 << 30;
    // For a negative product, bias the offset (1 - 2^30) so that the
    // midpoint rounds away from zero, matching VQRDMULH.
    if ((m1 < 0) ^ (m2 < 0))
    {
        mult = 1 - mult;
    }
    // Gets resolved as a SMLAL instruction
    mult = mult + (q63_t)m1 * m2;
    // Utilize all of the upper 32 bits. This is the doubling step
    // as well. Division is used rather than >> so that negative values
    // truncate toward zero (right-shifting a negative is not equivalent).
    result = (int32_t)(mult / (1ll << 31));
    // Q31_MIN * Q31_MIN is the only product whose doubled high half
    // overflows q31: saturate it explicitly.
    if ((m1 == m2) && (m1 == (int32_t)Q31_MIN))
    {
        result = Q31_MAX;
    }
    return result;
}
/**
 * @brief Doubling high multiply without saturation. This is intended
 *        for requantization where the scale is a positive integer
 *
 * @param[in] m1 Multiplicand. Range: {Q31_MIN, Q31_MAX}
 * @param[in] m2 Multiplier Range: {Q31_MIN, Q31_MAX}
 * @return Result of multiplication.
 * @note The result of this matches that of neon instruction
 *       VQRDMULH for m1 in range {Q31_MIN, Q31_MAX} and m2 in
 *       range {Q31_MIN + 1, Q31_MAX}. Saturation occurs when
 *       m1 equals m2 equals Q31_MIN and that is not handled by
 *       this function.
 */
__STATIC_FORCEINLINE q31_t arm_nn_doubling_high_mult_no_sat(const q31_t m1, const q31_t m2)
{
    q31_t result = 0;
    union arm_nn_long_long mult;
    // Rounding offset (+2^30) to add for a right shift of 31, placed
    // unconditionally in the low word of the 64-bit accumulator (unlike
    // the saturating variant, which biases it by sign).
    mult.word.low = 1 << 30;
    mult.word.high = 0;
    // Gets resolved as a SMLAL instruction
    mult.long_long = mult.long_long + (q63_t)m1 * m2;
    // Utilize all of the upper 32 bits. This is the doubling step
    // as well.
    result = (int32_t)(mult.long_long >> 31);
    return result;
}
  689. /**
  690. * @brief Rounding divide by power of two.
  691. * @param[in] dividend - Dividend
  692. * @param[in] exponent - Divisor = power(2, exponent)
  693. * Range: [0, 31]
  694. * @return Rounded result of division. Midpoint is rounded away from zero.
  695. *
  696. */
  697. __STATIC_FORCEINLINE q31_t arm_nn_divide_by_power_of_two(const q31_t dividend, const q31_t exponent)
  698. {
  699. q31_t result = 0;
  700. const q31_t remainder_mask = (1 << exponent) - 1;
  701. int32_t remainder = remainder_mask & dividend;
  702. // Basic division
  703. result = dividend >> exponent;
  704. // Adjust 'result' for rounding (mid point away from zero)
  705. q31_t threshold = remainder_mask >> 1;
  706. if (result < 0)
  707. {
  708. threshold++;
  709. }
  710. if (remainder > threshold)
  711. {
  712. result++;
  713. }
  714. return result;
  715. }
  716. /**
  717. * @brief Requantize a given value.
  718. * @param[in] val Value to be requantized
  719. * @param[in] multiplier multiplier. Range {Q31_MIN + 1, Q32_MAX}
  720. * @param[in] shift left or right shift for 'val * multiplier'
  721. *
  722. * @return Returns (val * multiplier)/(2 ^ shift)
  723. *
  724. */
  725. __STATIC_FORCEINLINE q31_t arm_nn_requantize(const q31_t val, const q31_t multiplier, const q31_t shift)
  726. {
  727. return arm_nn_divide_by_power_of_two(arm_nn_doubling_high_mult_no_sat(val * (1 << LEFT_SHIFT(shift)), multiplier),
  728. RIGHT_SHIFT(shift));
  729. }
  730. /**
  731. * @brief memcpy optimized for MVE
  732. * @param[in, out] dst Destination pointer
  733. * @param[in] src Source pointer.
  734. * @param[in] block_size Number of bytes to copy.
  735. *
  736. */
  737. __STATIC_FORCEINLINE void arm_memcpy_q7(q7_t *__RESTRICT dst, const q7_t *__RESTRICT src, uint32_t block_size)
  738. {
  739. #if defined(ARM_MATH_MVEI)
  740. __asm volatile(" wlstp.8 lr, %[cnt], 1f \n"
  741. "2: \n"
  742. " vldrb.8 q0, [%[in]], 16 \n"
  743. " vstrb.8 q0, [%[out]], 16 \n"
  744. " letp lr, 2b \n"
  745. "1: \n"
  746. : [ in ] "+r"(src), [ out ] "+r"(dst)
  747. : [ cnt ] "r"(block_size)
  748. : "q0", "memory", "r14");
  749. #else
  750. memcpy(dst, src, block_size);
  751. #endif
  752. }
  753. #if defined(ARM_MATH_MVEI)
  754. /**
  755. * @brief Vector saturating doubling high multiply returning high half.
  756. * @param[in] m1 Multiplicand
  757. * @param[in] m2 Multiplier
  758. * @return Result of multiplication.
  759. *
  760. */
  761. __STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve(const int32x4_t m1, const q31_t m2)
  762. {
  763. return vqrdmulhq_n_s32(m1, m2);
  764. }
  765. /**
  766. * @brief Vector rounding divide by power of two.
  767. * @param[in] dividend - Dividend vector
  768. * @param[in] exponent - Divisor = power(2, exponent)
  769. * Range: [0, 31]
  770. * @return Rounded result of division. Midpoint is rounded away from zero.
  771. *
  772. */
  773. __STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve(const int32x4_t dividend, const q31_t exponent)
  774. {
  775. const int32x4_t shift = vdupq_n_s32(-exponent);
  776. const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
  777. const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
  778. return vrshlq_s32(fixed_up_dividend, shift);
  779. }
  780. /**
  781. * @brief Requantize a given vector.
  782. * @param[in] val Vector to be requantized
  783. * @param[in] multiplier multiplier
  784. * @param[in] shift shift
  785. *
  786. * @return Returns (val * multiplier)/(2 ^ shift)
  787. *
  788. */
  789. __STATIC_FORCEINLINE int32x4_t arm_requantize_mve(const int32x4_t val, const q31_t multiplier, const q31_t shift)
  790. {
  791. return arm_divide_by_power_of_two_mve(
  792. arm_doubling_high_mult_mve(vshlq_s32(val, vdupq_n_s32(LEFT_SHIFT(shift))), multiplier), RIGHT_SHIFT(shift));
  793. }
/**
 * @brief Vector saturating doubling high multiply returning high half,
 *        with a per-lane multiplier.
 * @param[in] m1 Multiplicand vector
 * @param[in] m2 Multiplier vector
 * @return Result of multiplication.
 */
__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve_32x4(const int32x4_t m1, const int32x4_t m2)
{
    return vqrdmulhq_s32(m1, m2);
}
/**
 * @brief Vector rounding divide by power of two, with per-lane exponents.
 * @param[in] dividend - Dividend vector
 * @param[in] exponent - Per-lane divisor = power(2, exponent)
 * @return Rounded result of division. Midpoint is rounded away from zero.
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve_32x4(const int32x4_t dividend, const int32x4_t exponent)
{
    const int32x4_t shift = -exponent;
    // -1 in lanes where (dividend & -exponent) is negative; subtracting 1
    // before the rounding shift makes midpoints round away from zero.
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    // VRSHL with negative per-lane shifts = rounding arithmetic right shift.
    return vrshlq_s32(fixed_up_dividend, shift);
}
/**
 * @brief Requantize a vector with per-lane multipliers and shifts.
 * @param[in] val        Vector to be requantized
 * @param[in] multiplier Per-lane multiplier vector
 * @param[in] shift      Per-lane shift vector (positive = left shift)
 * @return Returns (val * multiplier)/(2 ^ shift) per lane
 */
__STATIC_FORCEINLINE int32x4_t arm_requantize_mve_32x4(const int32x4_t val,
                                                       const int32x4_t multiplier,
                                                       const int32x4_t shift)
{
    const int32x4_t zz = vdupq_n_s32(0);
    // Split 'shift' per lane: positive lanes shift left before the multiply,
    // non-positive lanes are negated into a rounding right shift afterwards.
    const mve_pred16_t p = vcmpgtq_n_s32(shift, 0);
    const int32x4_t left_shift = vpselq_s32(shift, zz, p);
    const int32x4_t right_shift = -vpselq_s32(zz, shift, p);

    return arm_divide_by_power_of_two_mve_32x4(arm_doubling_high_mult_mve_32x4(vshlq_s32(val, left_shift), multiplier),
                                               right_shift);
}
  816. #endif
  817. // @note The following functions are used only for softmax layer, scaled bits = 5 assumed
/**
 * @brief Fixed-point approximation of exp(val) for val <= 0.
 *        NOTE(review): input appears to be Q5.26 (5 scaled bits per the note
 *        above) and the result Q0.31 - confirm against the softmax callers.
 */
__STATIC_FORCEINLINE int32_t arm_nn_exp_on_negative_values(int32_t val)
{
    int32_t mask = 0;
    int32_t shift = 24;

    // Argument reduction: split val into a fractional part in [-1/4, 0)
    // (val_mod_minus_quarter) and the remaining integer-ish part (remainder)
    // that is handled by the table of exp constants below.
    const int32_t val_mod_minus_quarter = (val & ((1 << shift) - 1)) - (1 << shift);
    const int32_t remainder = val_mod_minus_quarter - val;
    // Rescale the fractional part and re-center around -1/8 (the 1 << 28 term).
    const int32_t x = (val_mod_minus_quarter << 5) + (1 << 28);
    const int32_t x2 = MUL_SAT(x, x);

    // Polynomial approximation of exp around the interval midpoint.
    // 1895147668 is exp(-1/8) in Q0.31 and 715827883 is 1/3 in Q0.31.
    int32_t result = 1895147668 +
        MUL_SAT(1895147668, x + DIV_POW2(MUL_SAT(DIV_POW2(MUL_SAT(x2, x2), 2) + MUL_SAT(x2, x), 715827883) + x2, 1));

// For each set bit of 'remainder', multiply in the corresponding precomputed
// exp(-2^k) constant (branchless select via mask).
#define SELECT_IF_NON_ZERO(x)                                                                                          \
    {                                                                                                                  \
        mask = MASK_IF_NON_ZERO(remainder & (1 << shift++));                                                           \
        result = SELECT_USING_MASK(mask, MUL_SAT(result, x), result);                                                  \
    }
    SELECT_IF_NON_ZERO(1672461947) // exp(-1/4) in Q0.31
    SELECT_IF_NON_ZERO(1302514674) // exp(-1/2) in Q0.31
    SELECT_IF_NON_ZERO(790015084)  // exp(-1) in Q0.31
    SELECT_IF_NON_ZERO(290630308)  // exp(-2) in Q0.31
    SELECT_IF_NON_ZERO(39332535)   // exp(-4) in Q0.31
    SELECT_IF_NON_ZERO(720401)     // exp(-8) in Q0.31
    SELECT_IF_NON_ZERO(242)        // exp(-16) in Q0.31
#undef SELECT_IF_NON_ZERO

    // exp(0) == 1, which saturates to Q31_MAX in Q0.31.
    mask = MASK_IF_ZERO(val);
    return SELECT_USING_MASK(mask, Q31_MAX, result);
}
  844. __STATIC_FORCEINLINE q31_t arm_nn_mult_by_power_of_two(const int32_t val, const int32_t exp)
  845. {
  846. const int32_t thresh = ((1 << (31 - exp)) - 1);
  847. int32_t result = val << exp;
  848. result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val > thresh), Q31_MAX, result);
  849. result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val < -thresh), Q31_MIN, result);
  850. return result;
  851. }
  852. __STATIC_FORCEINLINE int32_t arm_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)
  853. {
  854. const int64_t sum = (int64_t)val + (int64_t)Q31_MAX;
  855. const int32_t half_denominator = (int32_t)((sum + (sum >= 0 ? 1 : -1)) / 2L);
  856. int32_t x = 1515870810 + MUL_SAT(half_denominator, -1010580540);
  857. const int32_t shift = (1 << 29);
  858. x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
  859. x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
  860. x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
  861. return MUL_POW2(x, 1);
  862. }
  863. /**
  864. @brief Write 2 q15 elements and post increment pointer.
  865. @param[in] dest_q15 Pointer to pointer that holds address of destination.
  866. @param[in] src_q31 Input value to be written.
  867. @return none
  868. */
  869. __STATIC_FORCEINLINE void arm_nn_write_q15x2_ia(q15_t **dest_q15, q31_t src_q31)
  870. {
  871. q31_t val = src_q31;
  872. memcpy(*dest_q15, &val, 4);
  873. *dest_q15 += 2;
  874. }
  875. #ifdef __cplusplus
  876. }
  877. #endif
  878. #endif