/*
 * Copyright (C) 2010-2022 Arm Limited or its affiliates.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_svdf_s8.c
 * Description:  S8 basic SVDF layer function
 *
 * $Date:        28 April 2022
 * $Revision:    V.3.0.1
 *
 * Target Processor:  Cortex-M processors
 *
 * -------------------------------------------------------------------- */
  29. #include "arm_nnfunctions.h"
  30. #include "arm_nnsupportfunctions.h"
/**
 * @ingroup groupNN
 */

/**
 * @addtogroup SVDF
 * @{
 */

/*
 * S8 SVDF layer function for TensorFlow Lite with 8 bit state tensor
 *
 * Refer to header file for details.
 *
 */
arm_status arm_svdf_s8(const cmsis_nn_context *input_ctx,
                       const cmsis_nn_context *output_ctx,
                       const cmsis_nn_svdf_params *svdf_params,
                       const cmsis_nn_per_tensor_quant_params *input_quant_params,
                       const cmsis_nn_per_tensor_quant_params *output_quant_params,
                       const cmsis_nn_dims *input_dims,
                       const q7_t *input_data,
                       const cmsis_nn_dims *state_dims,
                       q7_t *state_data,
                       const cmsis_nn_dims *weights_feature_dims,
                       const q7_t *weights_feature_data,
                       const cmsis_nn_dims *weights_time_dims,
                       const q7_t *weights_time_data,
                       const cmsis_nn_dims *bias_dims,
                       const q31_t *bias_data,
                       const cmsis_nn_dims *output_dims,
                       q7_t *output_data)
{
    /* These dims are part of the common SVDF prototype but all sizes used
     * below are derived from input_dims / weights_*_dims instead. */
    (void)bias_dims;
    (void)state_dims;
    (void)output_dims;

    /* Per-tensor requantization parameters: one pair for the
     * input * feature-weight product, one pair for the final output. */
    const q31_t multiplier_in = input_quant_params->multiplier;
    const q31_t shift_in = input_quant_params->shift;
    const q31_t multiplier_out = output_quant_params->multiplier;
    const q31_t shift_2 = output_quant_params->shift;

    /* Asymmetric-quantization zero points and clamp ranges. */
    const int32_t zp_in = svdf_params->input_offset;
    const int32_t zp_out = svdf_params->output_offset;
    const int32_t in_activation_min = svdf_params->input_activation.min;
    const int32_t in_activation_max = svdf_params->input_activation.max;
    const int32_t out_activation_min = svdf_params->output_activation.min;
    const int32_t out_activation_max = svdf_params->output_activation.max;
    const int16_t rank = svdf_params->rank;

    const int32_t input_batches = input_dims->n;
    const int32_t input_height = input_dims->h;
    const int32_t feature_batches = weights_feature_dims->n;
    const int32_t time_batches = weights_time_dims->h;
    /* Each output unit is the sum of 'rank' consecutive feature rows.
     * NOTE(review): assumes feature_batches is a multiple of rank — enforced
     * by the caller/converter, not checked here. */
    const int32_t unit_count = feature_batches / rank;

    /* Both scratch buffers are mandatory; buffer_a holds the q31 time-dot
     * products, buffer_b the rank-reduced (and optionally biased) sums. */
    if (input_ctx->buf == NULL)
    {
        return ARM_MATH_ARGUMENT_ERROR;
    }
    q31_t *buffer_a = (q31_t *)input_ctx->buf;
    if (output_ctx->buf == NULL)
    {
        return ARM_MATH_ARGUMENT_ERROR;
    }
    q31_t *buffer_b = (q31_t *)output_ctx->buf;

    /* Left shift state: the state tensor is laid out [batch][feature][time]
     * (see res_ptr indexing below). Shifting the whole buffer left by one
     * int8 drops the oldest sample of every (batch, feature) row; regions
     * overlap, hence memmove rather than memcpy. */
    memmove((int8_t *)state_data,
            (int8_t *)state_data + 1,
            (size_t)((input_batches * feature_batches * time_batches - 1) * (int32_t)sizeof(int8_t)));

    /* Matrix multiplication input * feature weight. The per-feature results
     * are written straight into the freed last time slot of each state row,
     * using time_batches as the output address stride (last argument). */
    for (int i_batch = 0; i_batch < input_batches; i_batch++)
    {
        q7_t *res_ptr = state_data + (time_batches * i_batch * feature_batches) + (time_batches - 1);
        const q7_t *weight = weights_feature_data;
        const q7_t *input = input_data + i_batch * input_height;

        /* No bias (NULL), input offset -zp_in, output offset 0, clamped to
         * the input-activation range since results live in the s8 state. */
        arm_status res = arm_nn_vec_mat_mult_t_s8(input,
                                                  weight,
                                                  NULL,
                                                  res_ptr,
                                                  -zp_in,
                                                  0,
                                                  0,
                                                  multiplier_in,
                                                  shift_in,
                                                  input_height,
                                                  feature_batches,
                                                  in_activation_min,
                                                  in_activation_max,
                                                  time_batches);
        if (res != ARM_MATH_SUCCESS)
        {
            return res;
        }
    }

    /* Matrix multiply time weight * state tensors: for every (batch, feature)
     * row, dot-product the time weights with the state history into buffer_a.
     * v2 walks the state buffer contiguously across the whole loop nest. */
    {
        q31_t *ptr_a = buffer_a;
        const int8_t *v2 = state_data;
        for (int i_batch = 0; i_batch < input_batches; i_batch++)
        {
            const int8_t *v1 = weights_time_data;
            for (int i_feature_batch = 0; i_feature_batch < feature_batches; i_feature_batch++)
            {
                *ptr_a = 0;
                int32_t sum = 0;
#if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
                // Perform matrix multiplication in blocks of four
                int j = 0;
                int32_t block_count = time_batches >> 2;
                for (int i = 0; i < block_count; i++)
                {
                    j += 4;

                    /* read_and_pad_reordered loads 4 int8 values, sign-extends
                     * them into two packed 16x2 words and advances the pointer;
                     * __SMLAD then accumulates two 16-bit products per call. */
                    q31_t r1_1, r1_2, r2_1, r2_2;
                    v1 = read_and_pad_reordered(v1, &r1_1, &r1_2);
                    v2 = read_and_pad_reordered(v2, &r2_1, &r2_2);

                    sum = __SMLAD(r1_1, r2_1, sum);
                    sum = __SMLAD(r1_2, r2_2, sum);
                }

                // Process the remaining data (time_batches % 4 tail elements)
                for (; j < time_batches; j++)
                {
                    sum += *v1 * *v2;
                    v1++;
                    v2++;
                }
#else
                /* Plain scalar dot product for non-DSP / MVE builds. */
                for (int j = 0; j < time_batches; j++)
                {
                    sum += *v1 * *v2;
                    v1++;
                    v2++;
                }
#endif
                *ptr_a = sum;
                ptr_a++;
            }
        }
    }

    /* Rank reduction: collapse each group of 'rank' feature sums in buffer_a
     * into one unit sum in buffer_b, adding the per-unit bias when present. */
    if (bias_data)
    {
        if (unit_count == feature_batches)
        {
            /* rank == 1: plain element-wise bias add, no reduction needed. */
            for (int i = 0; i < input_batches; i++)
            {
                q31_t *output_temp = buffer_b + i * feature_batches;
                const q31_t *ptr_a = buffer_a + i * feature_batches;

                const int32_t *bi = bias_data;
                for (int j = 0; j < feature_batches; j++)
                {
                    output_temp[j] = ptr_a[j] + bi[j];
                }
            }
        }
        else
        {
            /* rank > 1: accumulate 'rank' consecutive values per unit,
             * seeding the accumulator with the unit's bias. */
            for (int i_batch = 0; i_batch < input_batches; i_batch++)
            {
                q31_t *output_data_temp = buffer_b + i_batch * unit_count;
                q31_t *ptr_a = buffer_a + i_batch * feature_batches;

                for (int i = 0; i < unit_count; i++)
                {
                    int32_t sum = bias_data[i];
                    for (int j = 0; j < rank; j++)
                    {
                        sum += *ptr_a;
                        ptr_a++;
                    }
                    output_data_temp[i] = sum;
                }
            }
        }
    }
    else
    {
        /* No bias: same rank reduction with a zero-seeded accumulator. */
        for (int i_batch = 0; i_batch < input_batches; i_batch++)
        {
            q31_t *output_data_temp = buffer_b + i_batch * unit_count;
            q31_t *ptr_a = buffer_a + i_batch * feature_batches;

            for (int i = 0; i < unit_count; i++)
            {
                int32_t sum = 0;
                for (int j = 0; j < rank; j++)
                {
                    sum += *ptr_a;
                    ptr_a++;
                }
                output_data_temp[i] = sum;
            }
        }
    }

    /* Final stage: requantize buffer_b to s8, add the output zero point,
     * clamp to the output-activation range and store to output_data. */
#if defined(ARM_MATH_MVEI)
    int32_t num_elements = input_batches * unit_count;
    const int32_t loop_count = (num_elements + 3) / 4;
    for (int i_op = 0; i_op < loop_count; i_op++)
    {
        /* vctp32q predicates the final iteration when num_elements < 4. */
        mve_pred16_t p = vctp32q((uint32_t)num_elements);
        int32x4_t op = vldrwq_z_s32(buffer_b, p);
        op = arm_requantize_mve(op, multiplier_out, shift_2);
        op = vaddq_n_s32(op, zp_out);
        const int32x4_t min_vec = vdupq_n_s32((int8_t)out_activation_min);
        const int32x4_t max_vec = vdupq_n_s32((int8_t)out_activation_max);
        op = vmaxq_s32(op, min_vec);
        op = vminq_s32(op, max_vec);
        /* Predicated narrowing store: writes the low byte of each lane. */
        vstrbq_p_s32(output_data, op, p);
        output_data += 4;
        buffer_b += 4;
        num_elements -= 4;
    }
#else
    /* Scalar fallback. Note CMSIS's CLAMP takes (value, max, min). */
    for (int i = 0; i < input_batches * unit_count; i++)
    {
        output_data[i] = (q7_t)CLAMP(
            arm_nn_requantize(buffer_b[i], multiplier_out, shift_2) + zp_out, out_activation_max, out_activation_min);
    }
#endif
    return (ARM_MATH_SUCCESS);
}
/**
 * @} end of SVDF group
 */