arm_svdf_s8.c 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. /*
  2. * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_svdf_s8.c
  21. * Description: S8 basic SVDF layer function
  22. *
  23. * $Date: 15. April 2021
  24. * $Revision: V.1.5.0
  25. *
  26. * Target Processor: Cortex-M processors
  27. *
  28. * -------------------------------------------------------------------- */
  29. #include "arm_nnfunctions.h"
  30. #include "arm_nnsupportfunctions.h"
  31. /**
  32. * @ingroup groupNN
  33. */
  34. /**
  35. * @addtogroup SVDF
  36. * @{
  37. */
  38. /*
  39. * S8 SVDF layer function for TensorFlow Lite
  40. *
  41. * Refer to header file for details.
  42. *
  43. */
  44. arm_status arm_svdf_s8(const cmsis_nn_context *input_ctx,
  45. const cmsis_nn_context *output_ctx,
  46. const cmsis_nn_svdf_params *svdf_params,
  47. const cmsis_nn_per_tensor_quant_params *input_quant_params,
  48. const cmsis_nn_per_tensor_quant_params *output_quant_params,
  49. const cmsis_nn_dims *input_dims,
  50. const q7_t *input_data,
  51. const cmsis_nn_dims *state_dims,
  52. q15_t *state_data,
  53. const cmsis_nn_dims *weights_feature_dims,
  54. const q7_t *weights_feature_data,
  55. const cmsis_nn_dims *weights_time_dims,
  56. const q15_t *weights_time_data,
  57. const cmsis_nn_dims *bias_dims,
  58. const q31_t *bias_data,
  59. const cmsis_nn_dims *output_dims,
  60. q7_t *output_data)
  61. {
  62. (void)bias_dims;
  63. (void)state_dims;
  64. (void)output_dims;
  65. const q31_t multiplier_in = input_quant_params->multiplier;
  66. const q31_t shift_in = input_quant_params->shift;
  67. const q31_t multiplier_out = output_quant_params->multiplier;
  68. const q31_t shift_2 = output_quant_params->shift;
  69. const int32_t zp_in = svdf_params->input_offset;
  70. const int32_t zp_out = svdf_params->output_offset;
  71. const int32_t in_activation_min = svdf_params->input_activation.min;
  72. const int32_t in_activation_max = svdf_params->input_activation.max;
  73. const int32_t out_activation_min = svdf_params->output_activation.min;
  74. const int32_t out_activation_max = svdf_params->output_activation.max;
  75. const int16_t rank = svdf_params->rank;
  76. const int32_t input_batches = input_dims->n;
  77. const int32_t input_height = input_dims->h;
  78. const int32_t feature_batches = weights_feature_dims->n;
  79. const int32_t time_batches = weights_time_dims->h;
  80. const int32_t unit_count = feature_batches / rank;
  81. q31_t *buffer_a = (q31_t *)input_ctx->buf;
  82. q31_t *buffer_b = (q31_t *)output_ctx->buf;
  83. memmove((q15_t *)state_data,
  84. (q15_t *)state_data + 1,
  85. (size_t)(input_batches * feature_batches * time_batches * (int32_t)sizeof(int16_t)));
  86. for (int i_batch = 0; i_batch < input_batches; i_batch++)
  87. {
  88. q15_t *res_ptr = state_data + (time_batches * i_batch * feature_batches) + (time_batches - 1);
  89. const q7_t *weight = weights_feature_data;
  90. const q7_t *input = input_data + i_batch * input_height;
  91. arm_status res = arm_nn_vec_mat_mult_t_svdf_s8(input,
  92. weight,
  93. res_ptr,
  94. -zp_in,
  95. 0,
  96. time_batches,
  97. multiplier_in,
  98. shift_in,
  99. input_height,
  100. feature_batches,
  101. in_activation_min,
  102. in_activation_max);
  103. if (res != ARM_MATH_SUCCESS)
  104. {
  105. return res;
  106. }
  107. }
  108. {
  109. q31_t *ptr_a = buffer_a;
  110. const q15_t *v2 = state_data;
  111. for (int i_batch = 0; i_batch < input_batches; i_batch++)
  112. {
  113. const q15_t *v1 = weights_time_data;
  114. for (int i_feature_batch = 0; i_feature_batch < feature_batches; i_feature_batch++)
  115. {
  116. *ptr_a = 0;
  117. int32_t sum = 0;
  118. #if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
  119. int j = 0;
  120. int32_t block_count = time_batches >> 1;
  121. for (int i = 0; i < block_count; i++)
  122. {
  123. j += 2;
  124. q31_t r1 = arm_nn_read_q15x2_ia(&v1);
  125. q31_t r2 = arm_nn_read_q15x2_ia(&v2);
  126. sum = __SMLAD(r1, r2, sum);
  127. }
  128. // Process the remaining data
  129. for (; j < time_batches; j++)
  130. {
  131. sum += *v1 * *v2;
  132. v1++;
  133. v2++;
  134. }
  135. #else
  136. for (int j = 0; j < time_batches; j++)
  137. {
  138. sum += *v1 * *v2;
  139. v1++;
  140. v2++;
  141. }
  142. #endif
  143. *ptr_a = sum;
  144. ptr_a++;
  145. }
  146. }
  147. }
  148. if (bias_data)
  149. {
  150. if (unit_count == feature_batches)
  151. {
  152. for (int i = 0; i < input_batches; i++)
  153. {
  154. q31_t *output_temp = buffer_b + i * feature_batches;
  155. const q31_t *ptr_a = buffer_a + i * feature_batches;
  156. const int32_t *bi = bias_data;
  157. for (int j = 0; j < feature_batches; j++)
  158. {
  159. output_temp[j] = ptr_a[j] + bi[j];
  160. }
  161. }
  162. }
  163. else
  164. {
  165. for (int i_batch = 0; i_batch < input_batches; i_batch++)
  166. {
  167. q31_t *output_data_temp = buffer_b + i_batch * unit_count;
  168. q31_t *ptr_a = buffer_a + i_batch * feature_batches;
  169. for (int i = 0; i < unit_count; i++)
  170. {
  171. int32_t sum = bias_data[i];
  172. for (int j = 0; j < rank; j++)
  173. {
  174. sum += *ptr_a;
  175. ptr_a++;
  176. }
  177. output_data_temp[i] = sum;
  178. }
  179. }
  180. }
  181. }
  182. else
  183. {
  184. for (int i_batch = 0; i_batch < input_batches; i_batch++)
  185. {
  186. q31_t *output_data_temp = buffer_b + i_batch * unit_count;
  187. q31_t *ptr_a = buffer_a + i_batch * feature_batches;
  188. for (int i = 0; i < unit_count; i++)
  189. {
  190. int32_t sum = 0;
  191. for (int j = 0; j < rank; j++)
  192. {
  193. sum += *ptr_a;
  194. ptr_a++;
  195. }
  196. output_data_temp[i] = sum;
  197. }
  198. }
  199. }
  200. #if defined(ARM_MATH_MVEI)
  201. int32_t num_elements = input_batches * unit_count;
  202. const int32_t loop_count = (num_elements + 3) / 4;
  203. for (int i_op = 0; i_op < loop_count; i_op++)
  204. {
  205. mve_pred16_t p = vctp32q((uint32_t)num_elements);
  206. int32x4_t op = vldrwq_z_s32(buffer_b, p);
  207. op = arm_requantize_mve(op, multiplier_out, shift_2);
  208. op = vaddq_n_s32(op, zp_out);
  209. const int32x4_t min_vec = vdupq_n_s32((int8_t)out_activation_min);
  210. const int32x4_t max_vec = vdupq_n_s32((int8_t)out_activation_max);
  211. op = vmaxq_s32(op, min_vec);
  212. op = vminq_s32(op, max_vec);
  213. vstrbq_p_s32(output_data, op, p);
  214. output_data += 4;
  215. buffer_b += 4;
  216. num_elements -= 4;
  217. }
  218. #else
  219. for (int i = 0; i < input_batches * unit_count; i++)
  220. {
  221. output_data[i] = (q7_t)CLAMP(
  222. arm_nn_requantize(buffer_b[i], multiplier_out, shift_2) + zp_out, out_activation_max, out_activation_min);
  223. }
  224. #endif
  225. return (ARM_MATH_SUCCESS);
  226. }
  227. /**
  228. * @} end of SVDF group
  229. */