arm_nnsupportfunctions.h 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184
  1. /*
  2. * Copyright (C) 2010-2022 Arm Limited or its affiliates.
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_nnsupportfunctions.h
  21. * Description: Public header file of support functions for CMSIS NN Library
  22. *
  23. * $Date: 16. March 2022
  24. * $Revision: V.6.2.1
  25. *
  26. * Target Processor: Cortex-M CPUs
  27. * -------------------------------------------------------------------- */
  28. #ifndef _ARM_NNSUPPORTFUNCTIONS_H_
  29. #define _ARM_NNSUPPORTFUNCTIONS_H_
  30. #include "arm_nn_math_types.h"
  31. #include "arm_nn_types.h"
  32. #include <stdbool.h>
  33. #ifdef __cplusplus
  34. extern "C" {
  35. #endif
  36. #define LEFT_SHIFT(_shift) (_shift > 0 ? _shift : 0)
  37. #define RIGHT_SHIFT(_shift) (_shift > 0 ? 0 : -_shift)
  38. #define MASK_IF_ZERO(x) (x) == 0 ? ~0 : 0
  39. #define MASK_IF_NON_ZERO(x) (x) != 0 ? ~0 : 0
  40. #define SELECT_USING_MASK(mask, a, b) ((mask) & (a)) ^ (~(mask) & (b))
  41. #define MAX(A, B) ((A) > (B) ? (A) : (B))
  42. #define MIN(A, B) ((A) < (B) ? (A) : (B))
  43. #define CLAMP(x, h, l) MAX(MIN((x), (h)), (l))
  44. #define REDUCE_MULTIPLIER(_mult) ((_mult < 0x7FFF0000) ? ((_mult + (1 << 15)) >> 16) : 0x7FFF)
  45. /**
  46. * @brief definition to pack four 8 bit values.
  47. */
  48. #define PACK_Q7x4_32x1(v0, v1, v2, v3) \
  49. ((((int32_t)(v0) << 0) & (int32_t)0x000000FF) | (((int32_t)(v1) << 8) & (int32_t)0x0000FF00) | \
  50. (((int32_t)(v2) << 16) & (int32_t)0x00FF0000) | (((int32_t)(v3) << 24) & (int32_t)0xFF000000))
  51. /**
  52. * @brief Union for SIMD access of q31/q15/q7 types
  53. */
  54. union arm_nnword
  55. {
  56. q31_t word;
  57. /**< q31 type */
  58. q15_t half_words[2];
  59. /**< q15 type */
  60. q7_t bytes[4];
  61. /**< q7 type */
  62. };
  63. /**
  64. * @brief Union for data type long long
  65. */
  66. struct arm_nn_double
  67. {
  68. uint32_t low;
  69. int32_t high;
  70. };
  71. union arm_nn_long_long
  72. {
  73. int64_t long_long;
  74. struct arm_nn_double word;
  75. };
  76. /**
  77. * @defgroup nndata_convert Neural Network Data Conversion Functions
  78. *
  79. * Perform data type conversion in-between neural network operations
  80. *
  81. */
  82. /**
  83. * @brief Converts the elements of the q7 vector to q15 vector without left-shift
  84. * @param[in] *pSrc points to the q7 input vector
  85. * @param[out] *pDst points to the q15 output vector
  86. * @param[in] blockSize length of the input vector
  87. *
  88. */
  89. void arm_q7_to_q15_no_shift(const q7_t *pSrc, q15_t *pDst, uint32_t blockSize);
  90. /**
  91. * @brief Non-saturating addition of elements of a q7 vector
  92. * @param[in] *input Pointer to the q7 input vector
  93. * @param[out] *output Pointer to the q31 output variable.
  94. * @param[in] block_size length of the input vector
  95. * \par Description:
  96. *
  97. * 2^24 samples can be added without saturating the result.
  98. *
  99. * The equation used for the conversion process is:
  100. *
  101. * <pre>
  102. * sum = input[0] + input[1] + .. + input[block_size -1]
  103. * </pre>
  104. *
  105. * */
  106. void arm_nn_add_q7(const q7_t *input, q31_t *output, uint32_t block_size);
  107. /**
  108. * @brief Converts the elements of the q7 vector to reordered q15 vector without left-shift
  109. * @param[in] *pSrc points to the q7 input vector
  110. * @param[out] *pDst points to the q15 output vector
  111. * @param[in] blockSize length of the input vector
  112. * @return none.
  113. *
  114. */
  115. void arm_q7_to_q15_reordered_no_shift(const q7_t *pSrc, q15_t *pDst, uint32_t blockSize);
  116. /**
  117. * @brief Converts the elements from a q7 vector to a q15 vector with an added offset
  118. * @param[in] src pointer to the q7 input vector
  119. * @param[out] dst pointer to the q15 output vector
  120. * @param[in] block_size length of the input vector
  121. * @param[in] offset q7 offset to be added to each input vector element.
  122. *
  123. * \par Description:
  124. *
  125. * The equation used for the conversion process is:
  126. *
  127. * <pre>
  128. * dst[n] = (q15_t) src[n] + offset; 0 <= n < block_size.
  129. * </pre>
  130. *
  131. */
  132. void arm_q7_to_q15_with_offset(const q7_t *src, q15_t *dst, uint32_t block_size, q15_t offset);
  133. /**
  134. * @brief Converts the elements of the q7 vector to reordered q15 vector with an added offset
  135. * @param[in] src pointer to the q7 input vector
  136. * @param[out] dst pointer to the q15 output vector
  137. * @param[in] block_size length of the input vector
  138. * @param[in] offset offset to be added to each input vector element.
  139. * @return none.
  140. *
  141. * @details This function does the q7 to q15 expansion with re-ordering of bytes. Re-ordering is a consequence of
  142. * the sign extension intrinsic(DSP extension). The tail (i.e., last (N % 4) elements) retains its
  143. * original order.
  144. *
  145. */
  146. void arm_q7_to_q15_reordered_with_offset(const q7_t *src, q15_t *dst, uint32_t block_size, q15_t offset);
  147. /**
  148. * @brief Converts the elements from a q7 vector and accumulate to a q15 vector
  149. * @param[in] *src points to the q7 input vector
  150. * @param[out] *dst points to the q15 output vector
  151. * @param[in] block_size length of the input vector
  152. *
  153. * \par Description:
  154. *
  155. * The equation used for the conversion process is:
  156. *
  157. * <pre>
  158. * dst[n] += (q15_t) src[n] ; 0 <= n < block_size.
  159. * </pre>
  160. *
  161. */
  162. void arm_nn_accumulate_q7_to_q15(q15_t *dst, const q7_t *src, uint32_t block_size);
  163. /**
  164. * @brief Depthwise conv on an im2col buffer where the input channel equals output channel.
  165. * @param[in] row pointer to row
  166. * @param[in] col pointer to im2col buffer, always consists of 2 columns.
  167. * @param[in] num_ch number of channels
  168. * @param[in] out_shift pointer to per output channel requantization shift parameter.
  169. * @param[in] out_mult pointer to per output channel requantization multiplier parameter.
  170. * @param[in] out_offset output tensor offset.
  171. * @param[in] activation_min minimum value to clamp the output to. Range : int8
  172. * @param[in] activation_max maximum value to clamp the output to. Range : int8
  173. * @param[in] kernel_size number of elements in one column.
  174. * @param[in] output_bias per output channel bias. Range : int32
  175. * @param[out] out pointer to output
  176. * @return The function returns one of the two
  177. * 1. The incremented output pointer for a successful operation or
  178. * 2. NULL if implementation is not available.
  179. *
  180. * @details Supported framework: TensorFlow Lite micro.
  181. */
  182. q7_t *arm_nn_depthwise_conv_s8_core(const q7_t *row,
  183. const q15_t *col,
  184. const uint16_t num_ch,
  185. const int32_t *out_shift,
  186. const int32_t *out_mult,
  187. const int32_t out_offset,
  188. const int32_t activation_min,
  189. const int32_t activation_max,
  190. const uint16_t kernel_size,
  191. const int32_t *const output_bias,
  192. q7_t *out);
  193. /**
  194. * @brief General Matrix-multiplication function with per-channel requantization.
  195. * @param[in] input_row pointer to row operand
  196. * @param[in] input_col pointer to col operand
  197. * @param[in] output_ch number of rows of input_row
  198. * @param[in] col_batches number of column batches. Range: 1 to 4
  199. * @param[in] output_shift pointer to per output channel requantization shift parameter.
  200. * @param[in] output_mult pointer to per output channel requantization multiplier parameter.
  201. * @param[in] out_offset output tensor offset.
  202. * @param[in] col_offset input tensor(col) offset.
  203. * @param[in] row_offset kernel offset(row). Not used.
  204. * @param[in] out_activation_min minimum value to clamp the output to. Range : int8
  205. * @param[in] out_activation_max maximum value to clamp the output to. Range : int8
  206. * @param[in] row_len number of elements in each row
  207. * @param[in] bias per output channel bias. Range : int32
  208. * @param[in,out] out pointer to output
  209. * @return The function returns one of the two
  210. * 1. The incremented output pointer for a successful operation or
  211. * 2. NULL if implementation is not available.
  212. *
  213. * @details Supported framework: TensorFlow Lite
  214. */
  215. q7_t *arm_nn_mat_mult_s8(const q7_t *input_row,
  216. const q7_t *input_col,
  217. const uint16_t output_ch,
  218. const uint16_t col_batches,
  219. const int32_t *output_shift,
  220. const int32_t *output_mult,
  221. const int32_t out_offset,
  222. const int32_t col_offset,
  223. const int32_t row_offset,
  224. const int16_t out_activation_min,
  225. const int16_t out_activation_max,
  226. const uint16_t row_len,
  227. const int32_t *const bias,
  228. q7_t *out);
  229. /**
  230. * @brief Matrix-multiplication function for convolution with per-channel requantization for 16 bits convolution.
  231. * @param[in] input_a pointer to operand A
  232. * @param[in] input_b pointer to operand B, always consists of 2 vectors.
  233. * @param[in] output_ch number of rows of A
  234. * @param[in] out_shift pointer to per output channel requantization shift parameter.
  235. * @param[in] out_mult pointer to per output channel requantization multiplier parameter.
  236. * @param[in] activation_min minimum value to clamp the output to. Range : int16
  237. * @param[in] activation_max maximum value to clamp the output to. Range : int16
  238. * @param[in] num_col_a number of columns of A
  239. * @param[in] output_bias per output channel bias. Range : int64
  240. * @param[in,out] out_0 pointer to output
  241. * @return The function returns one of the two
  242. * 1. The incremented output pointer for a successful operation or
  243. * 2. NULL if implementation is not available.
  244. *
  245. * @details This function does the matrix multiplication of weight matrix for all output channels
  246. * with 2 columns from im2col and produces two elements/output_channel. The outputs are
  247. * clamped in the range provided by activation min and max.
  248. * Supported framework: TensorFlow Lite micro.
  249. */
  250. q15_t *arm_nn_mat_mult_kernel_s16(const q7_t *input_a,
  251. const q15_t *input_b,
  252. const int32_t output_ch,
  253. const int32_t *out_shift,
  254. const int32_t *out_mult,
  255. const int16_t activation_min,
  256. const int16_t activation_max,
  257. const int32_t num_col_a,
  258. const int64_t *const output_bias,
  259. q15_t *out_0);
  260. /**
  261. * @brief General Matrix-multiplication without requantization for one row & one column
  262. * @param[in] row_elements number of row elements
  263. * @param[in] row_base pointer to row operand
  264. * @param[in] col_base pointer to col operand
  265. * @param[out] sum_col pointer to store sum of column elements
  266. * @param[out] output pointer to store result of multiply-accumulate
  267. * @return The function returns the multiply-accumulated result of the row by column.
  268. *
  269. * @details Pseudo-code
  270. * *output = 0
  271. * sum_col = 0
  272. * for (i = 0; i < row_elements; i++)
  273. * *output += row_base[i] * col_base[i]
  274. * sum_col += col_base[i]
  275. *
  276. */
  277. arm_status arm_nn_mat_mul_core_1x_s8(int32_t row_elements,
  278. const int8_t *row_base,
  279. const int8_t *col_base,
  280. int32_t *const sum_col,
  281. int32_t *const output);
  282. /**
  283. * @brief Matrix-multiplication with requantization & activation function for four rows and one column
  284. * @param[in] row_elements number of row elements
  285. * @param[in] offset offset between rows. Can be the same as row_elements.
  286. * For e.g, in a 1x1 conv scenario with stride as 1.
  287. * @param[in] row_base pointer to row operand
  288. * @param[in] col_base pointer to col operand
  289. * @param[in] out_ch Number of output channels
  290. * @param[in] conv_params Pointer to convolution parameters like offsets and activation values
  291. * @param[in] quant_params Pointer to per-channel quantization parameters
  292. * @param[in] bias Pointer to per-channel bias
  293. * @param[out] output Pointer to output where int8 results are stored.
  294. *
  295. * @return The function returns the updated output pointer or NULL if implementation is not available.
  296. *
  297. * @details Compliant to TFLM int8 specification. MVE implementation only
  298. */
  299. int8_t *arm_nn_mat_mul_core_4x_s8(const int32_t row_elements,
  300. const int32_t offset,
  301. const int8_t *row_base,
  302. const int8_t *col_base,
  303. const int32_t out_ch,
  304. const cmsis_nn_conv_params *conv_params,
  305. const cmsis_nn_per_channel_quant_params *quant_params,
  306. const int32_t *bias,
  307. int8_t *output);
  308. /**
  309. * @brief General Matrix-multiplication function with per-channel requantization.
  310. * This function assumes:
  311. * - LHS input matrix NOT transposed (nt)
  312. * - RHS input matrix transposed (t)
  313. *
  314. * @note This operation also performs the broadcast bias addition before the requantization
  315. *
  316. * @param[in] lhs Pointer to the LHS input matrix
  317. * @param[in] rhs Pointer to the RHS input matrix
  318. * @param[in] bias Pointer to the bias vector. The length of this vector is equal to the number of
  319. * output columns (or RHS input rows)
  320. * @param[out] dst Pointer to the output matrix with "m" rows and "n" columns
  321. * @param[in] dst_multipliers Pointer to the multipliers vector needed for the per-channel requantization.
  322. * The length of this vector is equal to the number of output columns (or RHS input
  323. * rows)
  324. * @param[in] dst_shifts Pointer to the shifts vector needed for the per-channel requantization. The length
  325. * of this vector is equal to the number of output columns (or RHS input rows)
  326. * @param[in] lhs_rows Number of LHS input rows
  327. * @param[in] rhs_rows Number of RHS input rows
  328. * @param[in] rhs_cols Number of LHS/RHS input columns
  329. * @param[in] lhs_offset Offset to be applied to the LHS input value
  330. * @param[in] dst_offset Offset to be applied the output result
  331. * @param[in] activation_min Minimum value to clamp down the output. Range : int8
  332. * @param[in] activation_max Maximum value to clamp up the output. Range : int8
  333. *
  334. * @return The function returns <code>ARM_MATH_SUCCESS</code>
  335. *
  336. */
  337. arm_status arm_nn_mat_mult_nt_t_s8(const q7_t *lhs,
  338. const q7_t *rhs,
  339. const q31_t *bias,
  340. q7_t *dst,
  341. const int32_t *dst_multipliers,
  342. const int32_t *dst_shifts,
  343. const int32_t lhs_rows,
  344. const int32_t rhs_rows,
  345. const int32_t rhs_cols,
  346. const int32_t lhs_offset,
  347. const int32_t dst_offset,
  348. const int32_t activation_min,
  349. const int32_t activation_max);
  350. /**
  351. * @brief s8 Vector by Matrix (transposed) multiplication
  352. *
  353. * @param[in] lhs Input left-hand side vector
  354. * @param[in] rhs Input right-hand side matrix (transposed)
  355. * @param[in] bias Input bias
  356. * @param[out] dst Output vector
  357. * @param[in] lhs_offset Offset to be added to the input values of the left-hand side vector.
  358. * Range: -127 to 128
  359. * @param[in] rhs_offset Not used
  360. * @param[in] dst_offset Offset to be added to the output values. Range: -127 to 128
  361. * @param[in] dst_multiplier Output multiplier
  362. * @param[in] dst_shift Output shift
  363. * @param[in] rhs_cols Number of columns in the right-hand side input matrix
  364. * @param[in] rhs_rows Number of rows in the right-hand side input matrix
  365. * @param[in] activation_min Minimum value to clamp the output to. Range: int8
  366. * @param[in] activation_max Maximum value to clamp the output to. Range: int8
  367. *
  368. * @return The function returns <code>ARM_MATH_SUCCESS</code>
  369. *
  370. */
  371. arm_status arm_nn_vec_mat_mult_t_s8(const q7_t *lhs,
  372. const q7_t *rhs,
  373. const q31_t *bias,
  374. q7_t *dst,
  375. const int32_t lhs_offset,
  376. const int32_t rhs_offset,
  377. const int32_t dst_offset,
  378. const int32_t dst_multiplier,
  379. const int32_t dst_shift,
  380. const int32_t rhs_cols,
  381. const int32_t rhs_rows,
  382. const int32_t activation_min,
  383. const int32_t activation_max);
  384. /**
  385. * @brief s16 Vector by Matrix (transposed) multiplication
  386. *
  387. * @param[in] lhs Input left-hand side vector
  388. * @param[in] rhs Input right-hand side matrix (transposed)
  389. * @param[in] bias Input bias
  390. * @param[out] dst Output vector
  391. * @param[in] dst_multiplier Output multiplier
  392. * @param[in] dst_shift Output shift
  393. * @param[in] rhs_cols Number of columns in the right-hand side input matrix
  394. * @param[in] rhs_rows Number of rows in the right-hand side input matrix
  395. * @param[in] activation_min Minimum value to clamp the output to. Range: int16
  396. * @param[in] activation_max Maximum value to clamp the output to. Range: int16
  397. *
  398. * @return The function returns <code>ARM_MATH_SUCCESS</code>
  399. *
  400. */
  401. arm_status arm_nn_vec_mat_mult_t_s16(const q15_t *lhs,
  402. const q7_t *rhs,
  403. const q63_t *bias,
  404. q15_t *dst,
  405. const int32_t dst_multiplier,
  406. const int32_t dst_shift,
  407. const int32_t rhs_cols,
  408. const int32_t rhs_rows,
  409. const int32_t activation_min,
  410. const int32_t activation_max);
  411. /**
  412. * @brief s8 Vector by Matrix (transposed) multiplication with s16 output
  413. *
  414. * @param[in] lhs Input left-hand side vector
  415. * @param[in] rhs Input right-hand side matrix (transposed)
  416. * @param[out] dst Output vector
  417. * @param[in] lhs_offset Offset to be added to the input values of the left-hand side
  418. * vector. Range: -127 to 128
  419. * @param[in] rhs_offset Not used
  420. * @param[in] scatter_offset Address offset for dst. First output is stored at 'dst', the
  421. * second at 'dst + scatter_offset' and so on.
  422. * @param[in] dst_multiplier Output multiplier
  423. * @param[in] dst_shift Output shift
  424. * @param[in] rhs_cols Number of columns in the right-hand side input matrix
  425. * @param[in] rhs_rows Number of rows in the right-hand side input matrix
  426. * @param[in] activation_min Minimum value to clamp the output to. Range: int16
  427. * @param[in] activation_max Maximum value to clamp the output to. Range: int16
  428. *
  429. * @return The function returns <code>ARM_MATH_SUCCESS</code>
  430. *
  431. */
  432. arm_status arm_nn_vec_mat_mult_t_svdf_s8(const q7_t *lhs,
  433. const q7_t *rhs,
  434. q15_t *dst,
  435. const int32_t lhs_offset,
  436. const int32_t rhs_offset,
  437. const int32_t scatter_offset,
  438. const int32_t dst_multiplier,
  439. const int32_t dst_shift,
  440. const int32_t rhs_cols,
  441. const int32_t rhs_rows,
  442. const int32_t activation_min,
  443. const int32_t activation_max);
  444. /**
  445. * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in padded cases where
  446. * the padding is -lhs_offset(Range: int8). Dimensions are the same for lhs and rhs.
  447. *
  448. * @param[in] lhs Input left-hand side matrix
  449. * @param[in] rhs Input right-hand side matrix (transposed)
  450. * @param[in] lhs_offset LHS matrix offset(input offset). Range: -127 to 128
  451. * @param[in] num_ch Number of channels in LHS/RHS
  452. * @param[in] out_shift Per channel output shift. Length of vector is equal to number of channels
  453. * @param[in] out_mult Per channel output multiplier. Length of vector is equal to number of channels
  454. * @param[in] out_offset Offset to be added to the output values. Range: -127 to 128
  455. * @param[in] activation_min Minimum value to clamp the output to. Range: int8
  456. * @param[in] activation_max Maximum value to clamp the output to. Range: int8
  457. * @param[in] row_x_col (row_dimension * col_dimension) of LHS/RHS matrix
  458. * @param[in] output_bias Per channel output bias. Length of vector is equal to number of channels
  459. * @param[in] out Output pointer
  460. *
  461. * @return The function returns one of the two
  462. * - Updated output pointer if an implementation is available
  463. * - NULL if no implementation is available.
  464. *
  465. * @note If number of channels is not a multiple of 4, upto 3 elements outside the boundary will be read
  466. * out for the following.
  467. * - Output shift
  468. * - Output multiplier
  469. * - Output bias
  470. * - rhs
  471. */
  472. q7_t *arm_nn_depthwise_conv_nt_t_padded_s8(const q7_t *lhs,
  473. const q7_t *rhs,
  474. const int32_t lhs_offset,
  475. const uint16_t num_ch,
  476. const int32_t *out_shift,
  477. const int32_t *out_mult,
  478. const int32_t out_offset,
  479. const int32_t activation_min,
  480. const int32_t activation_max,
  481. const uint16_t row_x_col,
  482. const int32_t *const output_bias,
  483. q7_t *out);
  484. /**
  485. * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases.
  486. * Dimensions are the same for lhs and rhs.
  487. *
  488. * @param[in] lhs Input left-hand side matrix
  489. * @param[in] rhs Input right-hand side matrix (transposed)
  490. * @param[in] lhs_offset LHS matrix offset(input offset). Range: -127 to 128
  491. * @param[in] num_ch Number of channels in LHS/RHS
  492. * @param[in] out_shift Per channel output shift. Length of vector is equal to number of channels.
  493. * @param[in] out_mult Per channel output multiplier. Length of vector is equal to number of channels.
  494. * @param[in] out_offset Offset to be added to the output values. Range: -127 to 128
  495. * @param[in] activation_min Minimum value to clamp the output to. Range: int8
  496. * @param[in] activation_max Maximum value to clamp the output to. Range: int8
  497. * @param[in] row_x_col (row_dimension * col_dimension) of LHS/RHS matrix
  498. * @param[in] output_bias Per channel output bias. Length of vector is equal to number of channels.
  499. * @param[in] out Output pointer
  500. *
  501. * @return The function returns one of the two
  502. * - Updated output pointer if an implementation is available
  503. * - NULL if no implementation is available.
  504. *
  505. * @note If number of channels is not a multiple of 4, upto 3 elements outside the boundary will be read
  506. * out for the following.
  507. * - Output shift
  508. * - Output multiplier
  509. * - Output bias
  510. * - rhs
  511. */
  512. q7_t *arm_nn_depthwise_conv_nt_t_s8(const q7_t *lhs,
  513. const q7_t *rhs,
  514. const int32_t lhs_offset,
  515. const uint16_t num_ch,
  516. const int32_t *out_shift,
  517. const int32_t *out_mult,
  518. const int32_t out_offset,
  519. const int32_t activation_min,
  520. const int32_t activation_max,
  521. const uint16_t row_x_col,
  522. const int32_t *const output_bias,
  523. q7_t *out);
  524. /**
  525. *@brief Matrix-multiplication function for convolution with reordered columns
  526. *@param[in] pA pointer to operand A
  527. *@param[in] pInBuffer pointer to operand B, always conssists of 2 vectors
  528. *@param[in] ch_im_out numRow of A
  529. *@param[in] numCol_A numCol of A
  530. *@param[in] bias_shift amount of left-shift for bias
  531. *@param[in] out_shift amount of right-shift for output
  532. *@param[in] bias the bias
  533. *@param[in,out] pOut pointer to output
  534. *@return The function returns the incremented output pointer
  535. *
  536. *@details This function assumes that data in pInBuffer are reordered
  537. */
  538. q7_t *arm_nn_mat_mult_kernel_q7_q15_reordered(const q7_t *pA,
  539. const q15_t *pInBuffer,
  540. const uint16_t ch_im_out,
  541. const uint16_t numCol_A,
  542. const uint16_t bias_shift,
  543. const uint16_t out_shift,
  544. const q7_t *bias,
  545. q7_t *pOut);
  546. /**
  547. @brief Read 2 q15 elements and post increment pointer.
  548. @param[in] in_q15 Pointer to pointer that holds address of input.
  549. @return q31 value
  550. */
  551. __STATIC_FORCEINLINE q31_t arm_nn_read_q15x2_ia(const q15_t **in_q15)
  552. {
  553. q31_t val;
  554. memcpy(&val, *in_q15, 4);
  555. *in_q15 += 2;
  556. return (val);
  557. }
  558. /**
  559. @brief Read 4 q7 from q7 pointer and post increment pointer.
  560. @param[in] in_q7 Pointer to pointer that holds address of input.
  561. @return q31 value
  562. */
  563. __STATIC_FORCEINLINE q31_t arm_nn_read_q7x4_ia(const q7_t **in_q7)
  564. {
  565. q31_t val;
  566. memcpy(&val, *in_q7, 4);
  567. *in_q7 += 4;
  568. return (val);
  569. }
  570. /**
  571. @brief Read 2 q15 from q15 pointer.
  572. @param[in] in_q15 pointer to address of input.
  573. @return q31 value
  574. */
  575. __STATIC_FORCEINLINE q31_t arm_nn_read_q15x2(const q15_t *in_q15)
  576. {
  577. q31_t val;
  578. memcpy(&val, in_q15, 4);
  579. return (val);
  580. }
  581. /**
  582. @brief Read 4 q7 values.
  583. @param[in] in_q7 pointer to address of input.
  584. @return q31 value
  585. */
  586. __STATIC_FORCEINLINE q31_t arm_nn_read_q7x4(const q7_t *in_q7)
  587. {
  588. q31_t val;
  589. memcpy(&val, in_q7, 4);
  590. return (val);
  591. }
  592. /**
  593. @brief Write four q7 to q7 pointer and increment pointer afterwards.
  594. @param[in] in Double pointer to input value
  595. @param[in] value Four bytes to copy
  596. */
  597. __STATIC_FORCEINLINE void arm_nn_write_q7x4_ia(q7_t **in, q31_t value)
  598. {
  599. memcpy(*in, &value, 4);
  600. *in += 4;
  601. }
  602. /**
  603. * @brief memset optimized for MVE
  604. * @param[in, out] dst Destination pointer
  605. * @param[in] val Value to set
  606. * @param[in] block_size Number of bytes to copy.
  607. *
  608. */
  609. __STATIC_FORCEINLINE void arm_memset_q7(q7_t *dst, const q7_t val, uint32_t block_size)
  610. {
  611. #if defined(ARM_MATH_MVEI)
  612. __asm volatile(" vdup.8 q0, %[set_val] \n"
  613. " wlstp.8 lr, %[cnt], 1f \n"
  614. "2: \n"
  615. " vstrb.8 q0, [%[in]], 16 \n"
  616. " letp lr, 2b \n"
  617. "1: \n"
  618. : [ in ] "+r"(dst)
  619. : [ cnt ] "r"(block_size), [ set_val ] "r"(val)
  620. : "q0", "memory", "r14");
  621. #else
  622. memset(dst, val, block_size);
  623. #endif
  624. }
  625. #if defined(ARM_MATH_DSP)
  626. /**
  627. * @brief read and expand one q7 word into two q15 words
  628. */
  629. __STATIC_FORCEINLINE const q7_t *read_and_pad(const q7_t *source, q31_t *out1, q31_t *out2)
  630. {
  631. q31_t inA = arm_nn_read_q7x4_ia(&source);
  632. q31_t inAbuf1 = __SXTB16_RORn((uint32_t)inA, 8);
  633. q31_t inAbuf2 = __SXTB16(inA);
  634. #ifndef ARM_MATH_BIG_ENDIAN
  635. *out2 = (int32_t)(__PKHTB(inAbuf1, inAbuf2, 16));
  636. *out1 = (int32_t)(__PKHBT(inAbuf2, inAbuf1, 16));
  637. #else
  638. *out1 = (int32_t)(__PKHTB(inAbuf1, inAbuf2, 16));
  639. *out2 = (int32_t)(__PKHBT(inAbuf2, inAbuf1, 16));
  640. #endif
  641. return source;
  642. }
/**
 * @brief Read and expand one q7 word into two q31 words (two q15 each) with
 *        reordering: no re-pack step is performed, so even-indexed input
 *        bytes land in one output word and odd-indexed bytes in the other.
 * @param[in]  source Pointer to q7 input; read pointer is advanced by 4 bytes.
 * @param[out] out1   Two sign-extended elements (bytes 0/2 on little endian).
 * @param[out] out2   Two sign-extended elements (bytes 1/3 on little endian).
 * @return Incremented source pointer.
 */
__STATIC_FORCEINLINE const q7_t *read_and_pad_reordered(const q7_t *source, q31_t *out1, q31_t *out2)
{
    q31_t inA = arm_nn_read_q7x4_ia(&source);

#ifndef ARM_MATH_BIG_ENDIAN
    /* __SXTB16 extracts bytes 0/2; the rotated variant extracts bytes 1/3. */
    *out2 = __SXTB16(__ROR((uint32_t)inA, 8));
    *out1 = __SXTB16(inA);
#else
    *out1 = __SXTB16(__ROR((uint32_t)inA, 8));
    *out2 = __SXTB16(inA);
#endif

    return source;
}
/**
 * @brief Read and expand one q7 word into two q15x2 words with reordering
 *        (as read_and_pad_reordered in the original file) and add an offset
 *        to every expanded element using saturating 16-bit adds.
 * @param[in]  source Pointer to q7 input; read pointer is advanced by 4 bytes.
 * @param[out] out1   Two sign-extended, offset elements.
 * @param[out] out2   Two sign-extended, offset elements.
 * @param[in]  offset Offset value; expected duplicated in both halfwords so
 *                    __QADD16 applies it per q15 lane.
 * @return Incremented source pointer.
 */
__STATIC_FORCEINLINE const q7_t *
read_and_pad_reordered_with_offset(const q7_t *source, q31_t *out1, q31_t *out2, q31_t offset)
{
    q31_t inA = arm_nn_read_q7x4_ia(&source);

#ifndef ARM_MATH_BIG_ENDIAN
    /* __SXTB16 extracts bytes 0/2; the rotated variant extracts bytes 1/3. */
    *out2 = __SXTB16(__ROR((uint32_t)inA, 8));
    *out1 = __SXTB16(inA);
#else
    *out1 = __SXTB16(__ROR((uint32_t)inA, 8));
    *out2 = __SXTB16(inA);
#endif
    /* Two parallel saturating 16-bit additions per word. */
    *out1 = __QADD16(*out1, offset);
    *out2 = __QADD16(*out2, offset);

    return source;
}
  676. #endif
/**
 * @defgroup NNBasicMath Basic Math Functions for Neural Network Computation
 *
 * Basic Math Functions for Neural Network Computation
 *
 */

/**
 * @brief q15 vector multiplication with variable output shifts
 * @param[in]  *pSrcA    pointer to the first input vector
 * @param[in]  *pSrcB    pointer to the second input vector
 * @param[out] *pDst     pointer to the output vector
 * @param[in]  out_shift amount of right-shift for output
 * @param[in]  blockSize number of samples in each vector
 * @return none.
 *
 * <b>Scaling and Overflow Behavior:</b>
 * \par
 * The function uses saturating arithmetic.
 * Results outside of the allowable q15 range [0x8000 0x7FFF] will be saturated.
 */
void arm_nn_mult_q15(q15_t *pSrcA, q15_t *pSrcB, q15_t *pDst, const uint16_t out_shift, uint32_t blockSize);
/**
 * @brief q7 vector multiplication with variable output shifts
 * @param[in]  *pSrcA    pointer to the first input vector
 * @param[in]  *pSrcB    pointer to the second input vector
 * @param[out] *pDst     pointer to the output vector
 * @param[in]  out_shift amount of right-shift for output
 * @param[in]  blockSize number of samples in each vector
 * @return none.
 *
 * <b>Scaling and Overflow Behavior:</b>
 * \par
 * The function uses saturating arithmetic.
 * Results outside of the allowable q7 range [0x80 0x7F] will be saturated.
 */
void arm_nn_mult_q7(q7_t *pSrcA, q7_t *pSrcB, q7_t *pDst, const uint16_t out_shift, uint32_t blockSize);
/**
 * @brief Matrix-multiplication function for convolution with per-channel requantization.
 * @param[in]     input_a        pointer to operand A
 * @param[in]     input_b        pointer to operand B, always consists of 2 vectors.
 * @param[in]     output_ch      number of rows of A
 * @param[in]     out_shift      pointer to per output channel requantization shift parameter.
 * @param[in]     out_mult       pointer to per output channel requantization multiplier parameter.
 * @param[in]     out_offset     output tensor offset.
 * @param[in]     activation_min minimum value to clamp the output to. Range : int8
 * @param[in]     activation_max maximum value to clamp the output to. Range : int8
 * @param[in]     num_col_a      number of columns of A
 * @param[in]     output_bias    per output channel bias. Range : int32
 * @param[in,out] out_0          pointer to output
 * @return The function returns one of the two
 *         1. The incremented output pointer for a successful operation or
 *         2. NULL if implementation is not available.
 *
 * @details This function does the matrix multiplication of weight matrix for all output channels
 *          with 2 columns from im2col and produces two elements/output_channel. The outputs are
 *          clamped in the range provided by activation min and max.
 *          Supported framework: TensorFlow Lite micro.
 */
q7_t *arm_nn_mat_mult_kernel_s8_s16(const q7_t *input_a,
                                    const q15_t *input_b,
                                    const uint16_t output_ch,
                                    const int32_t *out_shift,
                                    const int32_t *out_mult,
                                    const int32_t out_offset,
                                    const int16_t activation_min,
                                    const int16_t activation_max,
                                    const uint16_t num_col_a,
                                    const int32_t *const output_bias,
                                    q7_t *out_0);
/**
 * @brief Common softmax function for s8 input and s8 or s16 output
 * @param[in]  input        Pointer to the input tensor
 * @param[in]  num_rows     Number of rows in the input tensor
 * @param[in]  row_size     Number of elements in each input row
 * @param[in]  mult         Input quantization multiplier
 * @param[in]  shift        Input quantization shift within the range [0, 31]
 * @param[in]  diff_min     Minimum difference with max in row. Used to check if
 *                          the quantized exponential operation can be performed
 * @param[in]  int16_output Indicating s8 output if 0 else s16 output
 * @param[out] output       Pointer to the output tensor
 *
 * @note Supported framework: TensorFlow Lite micro (bit-accurate)
 *
 */
void arm_nn_softmax_common_s8(const int8_t *input,
                              const int32_t num_rows,
                              const int32_t row_size,
                              const int32_t mult,
                              const int32_t shift,
                              const int32_t diff_min,
                              const bool int16_output,
                              void *output);
  769. /**
  770. * @brief macro for adding rounding offset
  771. */
  772. #ifndef ARM_NN_TRUNCATE
  773. #define NN_ROUND(out_shift) ((0x1 << out_shift) >> 1)
  774. #else
  775. #define NN_ROUND(out_shift) 0
  776. #endif
// Macros for shortening quantization functions' names and avoid long lines
// (each alias expands to the correspondingly named helper defined in this file)
#define MUL_SAT(a, b) arm_nn_doubling_high_mult((a), (b))
#define MUL_SAT_MVE(a, b) arm_doubling_high_mult_mve_32x4((a), (b))
#define MUL_POW2(a, b) arm_nn_mult_by_power_of_two((a), (b))
#define DIV_POW2(a, b) arm_nn_divide_by_power_of_two((a), (b))
#define DIV_POW2_MVE(a, b) arm_divide_by_power_of_two_mve((a), (b))
#define EXP_ON_NEG(x) arm_nn_exp_on_negative_values((x))
#define ONE_OVER1(x) arm_nn_one_over_one_plus_x_for_x_in_0_1((x))
  785. /**
  786. * @brief Saturating doubling high multiply. Result matches
  787. * NEON instruction VQRDMULH.
  788. * @param[in] m1 Multiplicand. Range: {NN_Q31_MIN, NN_Q31_MAX}
  789. * @param[in] m2 Multiplier. Range: {NN_Q31_MIN, NN_Q31_MAX}
  790. * @return Result of multiplication.
  791. *
  792. */
  793. __STATIC_FORCEINLINE q31_t arm_nn_doubling_high_mult(const q31_t m1, const q31_t m2)
  794. {
  795. q31_t result = 0;
  796. // Rounding offset to add for a right shift of 31
  797. q63_t mult = 1 << 30;
  798. if ((m1 < 0) ^ (m2 < 0))
  799. {
  800. mult = 1 - mult;
  801. }
  802. // Gets resolved as a SMLAL instruction
  803. mult = mult + (q63_t)m1 * m2;
  804. // Utilize all of the upper 32 bits. This is the doubling step
  805. // as well.
  806. result = (int32_t)(mult / (1ll << 31));
  807. if ((m1 == m2) && (m1 == (int32_t)NN_Q31_MIN))
  808. {
  809. result = NN_Q31_MAX;
  810. }
  811. return result;
  812. }
/**
 * @brief Doubling high multiply without saturation. This is intended
 *        for requantization where the scale is a positive integer
 *
 * @param[in] m1 Multiplicand. Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @param[in] m2 Multiplier Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @return Result of multiplication.
 * @note The result of this matches that of neon instruction
 *       VQRDMULH for m1 in range {NN_Q31_MIN, NN_Q31_MAX} and m2 in
 *       range {NN_Q31_MIN + 1, NN_Q31_MAX}. Saturation occurs when
 *       m1 equals m2 equals NN_Q31_MIN and that is not handled by
 *       this function.
 *
 */
__STATIC_FORCEINLINE q31_t arm_nn_doubling_high_mult_no_sat(const q31_t m1, const q31_t m2)
{
    q31_t result = 0;
    union arm_nn_long_long mult;

    /* Rounding offset to add for a right shift of 31. */
    /* NOTE(review): assumes the low/high word split of union arm_nn_long_long
       (declared elsewhere in this header) matches the target's endianness. */
    mult.word.low = 1 << 30;
    mult.word.high = 0;

    /* Gets resolved as a SMLAL instruction. */
    mult.long_long = mult.long_long + (q63_t)m1 * m2;

    /* Utilize all of the upper 32 bits. This is the doubling step as well. */
    result = (int32_t)(mult.long_long >> 31);

    return result;
}
  841. /**
  842. * @brief Rounding divide by power of two.
  843. * @param[in] dividend - Dividend
  844. * @param[in] exponent - Divisor = power(2, exponent)
  845. * Range: [0, 31]
  846. * @return Rounded result of division. Midpoint is rounded away from zero.
  847. *
  848. */
  849. __STATIC_FORCEINLINE q31_t arm_nn_divide_by_power_of_two(const q31_t dividend, const q31_t exponent)
  850. {
  851. q31_t result = 0;
  852. const q31_t remainder_mask = (1 << exponent) - 1;
  853. int32_t remainder = remainder_mask & dividend;
  854. // Basic division
  855. result = dividend >> exponent;
  856. // Adjust 'result' for rounding (mid point away from zero)
  857. q31_t threshold = remainder_mask >> 1;
  858. if (result < 0)
  859. {
  860. threshold++;
  861. }
  862. if (remainder > threshold)
  863. {
  864. result++;
  865. }
  866. return result;
  867. }
/**
 * @brief Requantize a given value.
 * @param[in] val        Value to be requantized
 * @param[in] multiplier multiplier. Range {NN_Q31_MIN + 1, Q32_MAX}
 * @param[in] shift      left or right shift for 'val * multiplier'
 *
 * @return Returns (val * multiplier)/(2 ^ shift)
 *
 */
__STATIC_FORCEINLINE q31_t arm_nn_requantize(const q31_t val, const q31_t multiplier, const q31_t shift)
{
#ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    /* Single-rounding path: one 64-bit product, then one rounded shift. */
    const int64_t total_shift = 31 - shift;
    const int64_t new_val = val * (int64_t)multiplier;

    /* Shift to one bit above the target position, then add the rounding bit
       and drop it ((x + 1) >> 1 rounds the midpoint up). */
    int32_t result = new_val >> (total_shift - 1);
    result = (result + 1) >> 1;

    return result;
#else
    /* Double-rounding path: apply any left shift before the doubling high
       multiply and any right shift after it. LEFT_SHIFT/RIGHT_SHIFT are
       macros defined elsewhere in this header.
       NOTE(review): presumably bit-matches TFLite's reference
       MultiplyByQuantizedMultiplier -- confirm against that implementation. */
    return arm_nn_divide_by_power_of_two(arm_nn_doubling_high_mult_no_sat(val * (1 << LEFT_SHIFT(shift)), multiplier),
                                         RIGHT_SHIFT(shift));
#endif
}
  890. /**
  891. * @brief Requantize a given 64 bit value.
  892. * @param[in] val Value to be requantized in the range {-(1<<47)} to {(1<<47) - 1}
  893. * @param[in] reduced_multiplier Reduced multiplier in the range {NN_Q31_MIN + 1, Q32_MAX} to {Q16_MIN + 1,
  894. * Q16_MAX}
  895. * @param[in] shift Left or right shift for 'val * multiplier' in the range {-31} to {7}
  896. *
  897. * @return Returns (val * multiplier)/(2 ^ shift)
  898. *
  899. */
  900. __STATIC_FORCEINLINE q31_t arm_nn_requantize_s64(const q63_t val, const q31_t reduced_multiplier, const q31_t shift)
  901. {
  902. const q63_t new_val = val * reduced_multiplier;
  903. q31_t result = new_val >> (14 - shift); // 64->32 bit reduction
  904. result = (result + 1) >> 1; // Last shift position and insert round
  905. return result;
  906. }
/**
 * @brief memcpy optimized for MVE
 * @param[in, out] dst        Destination pointer
 * @param[in]      src        Source pointer.
 * @param[in]      block_size Number of bytes to copy.
 *
 */
__STATIC_FORCEINLINE void arm_memcpy_q7(q7_t *__RESTRICT dst, const q7_t *__RESTRICT src, uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
    /* Tail-predicated (wlstp/letp) loop copying 16 bytes per iteration;
       predication handles any remainder smaller than 16 bytes. */
    __asm volatile(" wlstp.8 lr, %[cnt], 1f \n"
                   "2: \n"
                   " vldrb.8 q0, [%[in]], 16 \n"
                   " vstrb.8 q0, [%[out]], 16 \n"
                   " letp lr, 2b \n"
                   "1: \n"
                   : [ in ] "+r"(src), [ out ] "+r"(dst)
                   : [ cnt ] "r"(block_size)
                   : "q0", "memory", "r14"); /* r14 (lr) is the loop counter */
#else
    memcpy(dst, src, block_size);
#endif
}
  930. #if defined(ARM_MATH_MVEI)
/**
 * @brief Vector saturating doubling high multiply returning high half.
 * @param[in] m1 Multiplicand vector
 * @param[in] m2 Multiplier (scalar, applied to all lanes)
 * @return Result of multiplication.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve(const int32x4_t m1, const q31_t m2)
{
    /* vqrdmulhq_n_s32: rounding, saturating, doubling high-half multiply. */
    return vqrdmulhq_n_s32(m1, m2);
}
/**
 * @brief Vector rounding divide by power of two.
 * @param[in] dividend - Dividend vector
 * @param[in] exponent - Divisor = power(2, exponent)
 *                       Range: [0, 31]
 * @return Rounded result of division. Midpoint is rounded away from zero.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve(const int32x4_t dividend, const q31_t exponent)
{
    const int32x4_t shift = vdupq_n_s32(-exponent);
    /* 'fixup' is -1 for negative lanes (the sign bit survives the AND mask),
       adjusting the rounding shift so midpoints round away from zero. */
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    /* vrshlq with a negative shift is a rounding right shift. */
    return vrshlq_s32(fixed_up_dividend, shift);
}
/**
 * @brief Requantize a given vector.
 * @param[in] val        Vector to be requantized
 * @param[in] multiplier multiplier (scalar, applied to all lanes)
 * @param[in] shift      shift (positive = left shift, negative = right shift)
 *
 * @return Returns (val * multiplier)/(2 ^ shift)
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_requantize_mve(const int32x4_t val, const q31_t multiplier, const q31_t shift)
{
#ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    /* Single-rounding: force the final shift to be a rounding right shift of
       at least one bit, moving the rest of 'shift' to the left-shift side. */
    const int right_shift = MIN(-1, shift);
    const int left_shift = shift - right_shift;

    const int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
    const int32x4_t right_shift_dup = vdupq_n_s32(right_shift);

    int32x4_t result = vqdmulhq_n_s32(vshlq_s32(val, left_shift_dup), multiplier);
    result = vrshlq_s32(result, right_shift_dup);

    return result;
#else
    /* Double-rounding: LEFT_SHIFT/RIGHT_SHIFT (macros defined elsewhere in
       this header) split 'shift' into its left/right components. */
    return arm_divide_by_power_of_two_mve(
        arm_doubling_high_mult_mve(vshlq_s32(val, vdupq_n_s32(LEFT_SHIFT(shift))), multiplier), RIGHT_SHIFT(shift));
#endif
}
/**
 * @brief Vector saturating doubling high multiply, per-lane multiplier variant.
 * @param[in] m1 Multiplicand vector
 * @param[in] m2 Multiplier vector
 * @return Result of multiplication (rounding, saturating, high half doubled).
 */
__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve_32x4(const int32x4_t m1, const int32x4_t m2)
{
    return vqrdmulhq_s32(m1, m2);
}
/**
 * @brief Vector rounding divide by power of two, per-lane exponent variant.
 * @param[in] dividend Dividend vector
 * @param[in] exponent Per-lane exponents; divisor = power(2, exponent)
 * @return Rounded result of division. Midpoint is rounded away from zero.
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve_32x4(const int32x4_t dividend, const int32x4_t exponent)
{
    const int32x4_t shift = -exponent;
    /* 'fixup' is -1 for negative lanes, adjusting rounding away from zero. */
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    /* vrshlq with a negative shift is a rounding right shift. */
    return vrshlq_s32(fixed_up_dividend, shift);
}
/**
 * @brief Requantize a vector with per-lane multiplier and shift.
 * @param[in] val        Vector to be requantized
 * @param[in] multiplier Per-lane multipliers
 * @param[in] shift      Per-lane shifts (positive = left, negative = right)
 * @return Returns (val * multiplier)/(2 ^ shift) per lane
 */
__STATIC_FORCEINLINE int32x4_t arm_requantize_mve_32x4(const int32x4_t val,
                                                       const int32x4_t multiplier,
                                                       const int32x4_t shift)
{
#ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    /* Single-rounding: final rounding right shift of at least one bit. */
    const int32x4_t right_shift = vminq_s32(vdupq_n_s32(-1), shift);
    const int32x4_t left_shift = vqsubq_s32(shift, right_shift);

    int32x4_t result = vqdmulhq_s32(vshlq_s32(val, left_shift), multiplier);
    result = vrshlq_s32(result, right_shift);

    return result;
#else
    /* Double-rounding: predicate-select positive shifts into 'left_shift'
       and the magnitude of negative shifts into 'right_shift'. */
    const int32x4_t zz = vdupq_n_s32(0);
    const mve_pred16_t p = vcmpgtq_n_s32(shift, 0);
    const int32x4_t left_shift = vpselq_s32(shift, zz, p);
    const int32x4_t right_shift = -vpselq_s32(zz, shift, p);

    return arm_divide_by_power_of_two_mve_32x4(arm_doubling_high_mult_mve_32x4(vshlq_s32(val, left_shift), multiplier),
                                               right_shift);
#endif
}
  1011. #endif
// @note The following functions are used only for softmax layer, scaled bits = 5 assumed
__STATIC_FORCEINLINE int32_t arm_nn_exp_on_negative_values(int32_t val)
{
    int32_t mask = 0;
    int32_t shift = 24;

    /* Split val so that val = val_mod_minus_quarter + remainder, with
       val_mod_minus_quarter in (-2^24, 0]. */
    const int32_t val_mod_minus_quarter = (val & ((1 << shift) - 1)) - (1 << shift);
    const int32_t remainder = val_mod_minus_quarter - val;

    /* Polynomial approximation of exp() on the small interval; the constants
       are fixed-point -- NOTE(review): presumably taken from the gemmlowp
       reference implementation, confirm against it. */
    const int32_t x = (val_mod_minus_quarter << 5) + (1 << 28);
    const int32_t x2 = MUL_SAT(x, x);

    int32_t result = 1895147668 +
        MUL_SAT(1895147668, x + DIV_POW2(MUL_SAT(DIV_POW2(MUL_SAT(x2, x2), 2) + MUL_SAT(x2, x), 715827883) + x2, 1));

/* For each bit set in 'remainder', multiply in the corresponding precomputed
   constant. Note: the macro increments 'shift' as a side effect on each use. */
#define SELECT_IF_NON_ZERO(x)                                                                                          \
    {                                                                                                                  \
        mask = MASK_IF_NON_ZERO(remainder & (1 << shift++));                                                           \
        result = SELECT_USING_MASK(mask, MUL_SAT(result, x), result);                                                  \
    }

    SELECT_IF_NON_ZERO(1672461947)
    SELECT_IF_NON_ZERO(1302514674)
    SELECT_IF_NON_ZERO(790015084)
    SELECT_IF_NON_ZERO(290630308)
    SELECT_IF_NON_ZERO(39332535)
    SELECT_IF_NON_ZERO(720401)
    SELECT_IF_NON_ZERO(242)

#undef SELECT_IF_NON_ZERO

    /* For val == 0 return the fixed-point representation of 1 (NN_Q31_MAX). */
    mask = MASK_IF_ZERO(val);
    return SELECT_USING_MASK(mask, NN_Q31_MAX, result);
}
/**
 * @brief Saturating multiply by a power of two: val * 2^exp, clamped to the
 *        q31 range [NN_Q31_MIN, NN_Q31_MAX].
 */
__STATIC_FORCEINLINE q31_t arm_nn_mult_by_power_of_two(const int32_t val, const int32_t exp)
{
    /* Largest magnitude that survives the shift without overflowing. */
    const int32_t thresh = ((1 << (31 - exp)) - 1);
    /* NOTE(review): 'val << exp' is formally UB in C for negative/overflowing
       values; this relies on two's-complement shift behavior of the target
       compilers, with the out-of-range cases masked to saturation below. */
    int32_t result = val << exp;

    result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val > thresh), NN_Q31_MAX, result);
    result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val < -thresh), NN_Q31_MIN, result);

    return result;
}
/**
 * @brief Computes 1 / (1 + x) for x in [0, 1] in fixed point, using three
 *        Newton-Raphson iterations on the reciprocal of half the denominator.
 */
__STATIC_FORCEINLINE int32_t arm_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)
{
    /* half_denominator = (1 + x) / 2, rounded away from zero. */
    const int64_t sum = (int64_t)val + (int64_t)NN_Q31_MAX;
    const int32_t half_denominator = (int32_t)((sum + (sum >= 0 ? 1 : -1)) / 2L);

    /* Initial estimate 48/17 - 32/17 * half_denominator: the constants are
       48/17 and -32/17 in Q29 (1515870810 ~= 48/17 * 2^29,
       -1010580540 ~= -32/17 * 2^29). */
    int32_t x = 1515870810 + MUL_SAT(half_denominator, -1010580540);

    /* Three Newton-Raphson refinements: x += 4 * x * (1/2 - half_denominator * x),
       where 'shift' is 1/2 in Q30-scaled form (1 << 29). */
    const int32_t shift = (1 << 29);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);

    /* Undo the halving of the denominator by doubling the result. */
    return MUL_POW2(x, 1);
}
  1058. /**
  1059. @brief Write 2 q15 elements and post increment pointer.
  1060. @param[in] dest_q15 Pointer to pointer that holds address of destination.
  1061. @param[in] src_q31 Input value to be written.
  1062. */
  1063. __STATIC_FORCEINLINE void arm_nn_write_q15x2_ia(q15_t **dest_q15, q31_t src_q31)
  1064. {
  1065. q31_t val = src_q31;
  1066. memcpy(*dest_q15, &val, 4);
  1067. *dest_q15 += 2;
  1068. }
  1069. #ifdef __cplusplus
  1070. }
  1071. #endif
  1072. #endif