arm_svdf_s8.c 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. /*
  2. * Copyright (C) 2010-2021 Arm Limited or its affiliates.
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_svdf_s8.c
  21. * Description: S8 basic SVDF layer function
  22. *
  23. * $Date: 17. August 2021
  24. * $Revision: V.1.5.1
  25. *
  26. * Target Processor: Cortex-M processors
  27. *
  28. * -------------------------------------------------------------------- */
  29. #include "arm_nnfunctions.h"
  30. #include "arm_nnsupportfunctions.h"
  31. /**
  32. * @ingroup groupNN
  33. */
  34. /**
  35. * @addtogroup SVDF
  36. * @{
  37. */
  38. /*
  39. * S8 SVDF layer function for TensorFlow Lite
  40. *
  41. * Refer to header file for details.
  42. *
  43. */
  44. arm_status arm_svdf_s8(const cmsis_nn_context *input_ctx,
  45. const cmsis_nn_context *output_ctx,
  46. const cmsis_nn_svdf_params *svdf_params,
  47. const cmsis_nn_per_tensor_quant_params *input_quant_params,
  48. const cmsis_nn_per_tensor_quant_params *output_quant_params,
  49. const cmsis_nn_dims *input_dims,
  50. const q7_t *input_data,
  51. const cmsis_nn_dims *state_dims,
  52. q15_t *state_data,
  53. const cmsis_nn_dims *weights_feature_dims,
  54. const q7_t *weights_feature_data,
  55. const cmsis_nn_dims *weights_time_dims,
  56. const q15_t *weights_time_data,
  57. const cmsis_nn_dims *bias_dims,
  58. const q31_t *bias_data,
  59. const cmsis_nn_dims *output_dims,
  60. q7_t *output_data)
  61. {
  62. (void)bias_dims;
  63. (void)state_dims;
  64. (void)output_dims;
  65. const q31_t multiplier_in = input_quant_params->multiplier;
  66. const q31_t shift_in = input_quant_params->shift;
  67. const q31_t multiplier_out = output_quant_params->multiplier;
  68. const q31_t shift_2 = output_quant_params->shift;
  69. const int32_t zp_in = svdf_params->input_offset;
  70. const int32_t zp_out = svdf_params->output_offset;
  71. const int32_t in_activation_min = svdf_params->input_activation.min;
  72. const int32_t in_activation_max = svdf_params->input_activation.max;
  73. const int32_t out_activation_min = svdf_params->output_activation.min;
  74. const int32_t out_activation_max = svdf_params->output_activation.max;
  75. const int16_t rank = svdf_params->rank;
  76. const int32_t input_batches = input_dims->n;
  77. const int32_t input_height = input_dims->h;
  78. const int32_t feature_batches = weights_feature_dims->n;
  79. const int32_t time_batches = weights_time_dims->h;
  80. const int32_t unit_count = feature_batches / rank;
  81. if (input_ctx->buf == NULL)
  82. {
  83. return ARM_MATH_ARGUMENT_ERROR;
  84. }
  85. q31_t *buffer_a = (q31_t *)input_ctx->buf;
  86. if (output_ctx->buf == NULL)
  87. {
  88. return ARM_MATH_ARGUMENT_ERROR;
  89. }
  90. q31_t *buffer_b = (q31_t *)output_ctx->buf;
  91. memmove((q15_t *)state_data,
  92. (q15_t *)state_data + 1,
  93. (size_t)((input_batches * feature_batches * time_batches - 1) * (int32_t)sizeof(int16_t)));
  94. for (int i_batch = 0; i_batch < input_batches; i_batch++)
  95. {
  96. q15_t *res_ptr = state_data + (time_batches * i_batch * feature_batches) + (time_batches - 1);
  97. const q7_t *weight = weights_feature_data;
  98. const q7_t *input = input_data + i_batch * input_height;
  99. arm_status res = arm_nn_vec_mat_mult_t_svdf_s8(input,
  100. weight,
  101. res_ptr,
  102. -zp_in,
  103. 0,
  104. time_batches,
  105. multiplier_in,
  106. shift_in,
  107. input_height,
  108. feature_batches,
  109. in_activation_min,
  110. in_activation_max);
  111. if (res != ARM_MATH_SUCCESS)
  112. {
  113. return res;
  114. }
  115. }
  116. {
  117. q31_t *ptr_a = buffer_a;
  118. const q15_t *v2 = state_data;
  119. for (int i_batch = 0; i_batch < input_batches; i_batch++)
  120. {
  121. const q15_t *v1 = weights_time_data;
  122. for (int i_feature_batch = 0; i_feature_batch < feature_batches; i_feature_batch++)
  123. {
  124. *ptr_a = 0;
  125. int32_t sum = 0;
  126. #if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
  127. int j = 0;
  128. int32_t block_count = time_batches >> 1;
  129. for (int i = 0; i < block_count; i++)
  130. {
  131. j += 2;
  132. q31_t r1 = arm_nn_read_q15x2_ia(&v1);
  133. q31_t r2 = arm_nn_read_q15x2_ia(&v2);
  134. sum = __SMLAD(r1, r2, sum);
  135. }
  136. // Process the remaining data
  137. for (; j < time_batches; j++)
  138. {
  139. sum += *v1 * *v2;
  140. v1++;
  141. v2++;
  142. }
  143. #else
  144. for (int j = 0; j < time_batches; j++)
  145. {
  146. sum += *v1 * *v2;
  147. v1++;
  148. v2++;
  149. }
  150. #endif
  151. *ptr_a = sum;
  152. ptr_a++;
  153. }
  154. }
  155. }
  156. if (bias_data)
  157. {
  158. if (unit_count == feature_batches)
  159. {
  160. for (int i = 0; i < input_batches; i++)
  161. {
  162. q31_t *output_temp = buffer_b + i * feature_batches;
  163. const q31_t *ptr_a = buffer_a + i * feature_batches;
  164. const int32_t *bi = bias_data;
  165. for (int j = 0; j < feature_batches; j++)
  166. {
  167. output_temp[j] = ptr_a[j] + bi[j];
  168. }
  169. }
  170. }
  171. else
  172. {
  173. for (int i_batch = 0; i_batch < input_batches; i_batch++)
  174. {
  175. q31_t *output_data_temp = buffer_b + i_batch * unit_count;
  176. q31_t *ptr_a = buffer_a + i_batch * feature_batches;
  177. for (int i = 0; i < unit_count; i++)
  178. {
  179. int32_t sum = bias_data[i];
  180. for (int j = 0; j < rank; j++)
  181. {
  182. sum += *ptr_a;
  183. ptr_a++;
  184. }
  185. output_data_temp[i] = sum;
  186. }
  187. }
  188. }
  189. }
  190. else
  191. {
  192. for (int i_batch = 0; i_batch < input_batches; i_batch++)
  193. {
  194. q31_t *output_data_temp = buffer_b + i_batch * unit_count;
  195. q31_t *ptr_a = buffer_a + i_batch * feature_batches;
  196. for (int i = 0; i < unit_count; i++)
  197. {
  198. int32_t sum = 0;
  199. for (int j = 0; j < rank; j++)
  200. {
  201. sum += *ptr_a;
  202. ptr_a++;
  203. }
  204. output_data_temp[i] = sum;
  205. }
  206. }
  207. }
  208. #if defined(ARM_MATH_MVEI)
  209. int32_t num_elements = input_batches * unit_count;
  210. const int32_t loop_count = (num_elements + 3) / 4;
  211. for (int i_op = 0; i_op < loop_count; i_op++)
  212. {
  213. mve_pred16_t p = vctp32q((uint32_t)num_elements);
  214. int32x4_t op = vldrwq_z_s32(buffer_b, p);
  215. op = arm_requantize_mve(op, multiplier_out, shift_2);
  216. op = vaddq_n_s32(op, zp_out);
  217. const int32x4_t min_vec = vdupq_n_s32((int8_t)out_activation_min);
  218. const int32x4_t max_vec = vdupq_n_s32((int8_t)out_activation_max);
  219. op = vmaxq_s32(op, min_vec);
  220. op = vminq_s32(op, max_vec);
  221. vstrbq_p_s32(output_data, op, p);
  222. output_data += 4;
  223. buffer_b += 4;
  224. num_elements -= 4;
  225. }
  226. #else
  227. for (int i = 0; i < input_batches * unit_count; i++)
  228. {
  229. output_data[i] = (q7_t)CLAMP(
  230. arm_nn_requantize(buffer_b[i], multiplier_out, shift_2) + zp_out, out_activation_max, out_activation_min);
  231. }
  232. #endif
  233. return (ARM_MATH_SUCCESS);
  234. }
  235. /**
  236. * @} end of SVDF group
  237. */