arm_svdf_s8.c
/*
 * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project: CMSIS NN Library
 * Title: arm_svdf_s8.c
 * Description: S8 basic SVDF layer function
 *
 * $Date: 14 Feb 2024
 * $Revision: V.6.1.0
 *
 * Target : Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */
  29. #include "arm_nnfunctions.h"
  30. #include "arm_nnsupportfunctions.h"
/**
 * @ingroup Public
 */

/**
 * @addtogroup SVDF
 * @{
 */

/*
 * S8 SVDF layer function for TensorFlow Lite with 8 bit state tensor
 *
 * Refer to header file for details.
 *
 */
  44. arm_cmsis_nn_status arm_svdf_s8(const cmsis_nn_context *ctx,
  45. const cmsis_nn_context *input_ctx,
  46. const cmsis_nn_context *output_ctx,
  47. const cmsis_nn_svdf_params *svdf_params,
  48. const cmsis_nn_per_tensor_quant_params *input_quant_params,
  49. const cmsis_nn_per_tensor_quant_params *output_quant_params,
  50. const cmsis_nn_dims *input_dims,
  51. const int8_t *input_data,
  52. const cmsis_nn_dims *state_dims,
  53. int8_t *state_data,
  54. const cmsis_nn_dims *weights_feature_dims,
  55. const int8_t *weights_feature_data,
  56. const cmsis_nn_dims *weights_time_dims,
  57. const int8_t *weights_time_data,
  58. const cmsis_nn_dims *bias_dims,
  59. const int32_t *bias_data,
  60. const cmsis_nn_dims *output_dims,
  61. int8_t *output_data)
  62. {
  63. (void)bias_dims;
  64. (void)state_dims;
  65. (void)output_dims;
  66. #if defined(ARM_MATH_MVEI)
  67. if (ctx->buf == NULL)
  68. {
  69. return (ARM_CMSIS_NN_ARG_ERROR);
  70. }
  71. #endif
  72. const int32_t multiplier_in = input_quant_params->multiplier;
  73. const int32_t shift_in = input_quant_params->shift;
  74. const int32_t multiplier_out = output_quant_params->multiplier;
  75. const int32_t shift_2 = output_quant_params->shift;
  76. const int32_t zp_in = svdf_params->input_offset;
  77. const int32_t zp_out = svdf_params->output_offset;
  78. const int32_t in_activation_min = svdf_params->input_activation.min;
  79. const int32_t in_activation_max = svdf_params->input_activation.max;
  80. const int32_t out_activation_min = svdf_params->output_activation.min;
  81. const int32_t out_activation_max = svdf_params->output_activation.max;
  82. const int16_t rank = svdf_params->rank;
  83. const int32_t input_batches = input_dims->n;
  84. const int32_t input_height = input_dims->h;
  85. const int32_t feature_batches = weights_feature_dims->n;
  86. const int32_t time_batches = weights_time_dims->h;
  87. const int32_t unit_count = feature_batches / rank;
  88. if (input_ctx->buf == NULL)
  89. {
  90. return ARM_CMSIS_NN_ARG_ERROR;
  91. }
  92. int32_t *buffer_a = (int32_t *)input_ctx->buf;
  93. if (output_ctx->buf == NULL)
  94. {
  95. return ARM_CMSIS_NN_ARG_ERROR;
  96. }
  97. int32_t *buffer_b = (int32_t *)output_ctx->buf;
  98. int32_t *kernel_sum_data = (int32_t *)ctx->buf;
  99. // Left shift state
  100. memmove((int8_t *)state_data,
  101. (int8_t *)state_data + 1,
  102. (size_t)((input_batches * feature_batches * time_batches - 1) * (int32_t)sizeof(int8_t)));
  103. // Matrix multiplication input * feature weight
  104. for (int i_batch = 0; i_batch < input_batches; i_batch++)
  105. {
  106. int8_t *res_ptr = state_data + (time_batches * i_batch * feature_batches) + (time_batches - 1);
  107. const int8_t *input = input_data + i_batch * input_height;
  108. arm_cmsis_nn_status res = arm_nn_vec_mat_mult_t_s8(input,
  109. weights_feature_data,
  110. kernel_sum_data,
  111. NULL,
  112. res_ptr,
  113. -zp_in,
  114. 0,
  115. multiplier_in,
  116. shift_in,
  117. input_height,
  118. feature_batches,
  119. in_activation_min,
  120. in_activation_max,
  121. time_batches,
  122. 0);
  123. if (res != ARM_CMSIS_NN_SUCCESS)
  124. {
  125. return res;
  126. }
  127. }
  128. // Matrix multiplicate time weight * state tensors
  129. {
  130. int32_t *ptr_a = buffer_a;
  131. const int8_t *v2 = state_data;
  132. for (int i_batch = 0; i_batch < input_batches; i_batch++)
  133. {
  134. const int8_t *v1 = weights_time_data;
  135. for (int i_feature_batch = 0; i_feature_batch < feature_batches; i_feature_batch++)
  136. {
  137. *ptr_a = 0;
  138. int32_t sum = 0;
  139. #if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
  140. // Perform matrix multiplication in blocks of four
  141. int j = 0;
  142. int32_t block_count = time_batches >> 2;
  143. for (int i = 0; i < block_count; i++)
  144. {
  145. j += 4;
  146. int32_t r1_1, r1_2, r2_1, r2_2;
  147. v1 = read_and_pad_reordered(v1, &r1_1, &r1_2);
  148. v2 = read_and_pad_reordered(v2, &r2_1, &r2_2);
  149. sum = SMLAD(r1_1, r2_1, sum);
  150. sum = SMLAD(r1_2, r2_2, sum);
  151. }
  152. // Process the remaining data
  153. for (; j < time_batches; j++)
  154. {
  155. sum += *v1 * *v2;
  156. v1++;
  157. v2++;
  158. }
  159. #else
  160. for (int j = 0; j < time_batches; j++)
  161. {
  162. sum += *v1 * *v2;
  163. v1++;
  164. v2++;
  165. }
  166. #endif
  167. *ptr_a = sum;
  168. ptr_a++;
  169. }
  170. }
  171. }
  172. if (bias_data)
  173. {
  174. if (unit_count == feature_batches)
  175. {
  176. for (int i = 0; i < input_batches; i++)
  177. {
  178. int32_t *output_temp = buffer_b + i * feature_batches;
  179. const int32_t *ptr_a = buffer_a + i * feature_batches;
  180. const int32_t *bi = bias_data;
  181. for (int j = 0; j < feature_batches; j++)
  182. {
  183. output_temp[j] = ptr_a[j] + bi[j];
  184. }
  185. }
  186. }
  187. else
  188. {
  189. for (int i_batch = 0; i_batch < input_batches; i_batch++)
  190. {
  191. int32_t *output_data_temp = buffer_b + i_batch * unit_count;
  192. int32_t *ptr_a = buffer_a + i_batch * feature_batches;
  193. for (int i = 0; i < unit_count; i++)
  194. {
  195. int32_t sum = bias_data[i];
  196. for (int j = 0; j < rank; j++)
  197. {
  198. sum += *ptr_a;
  199. ptr_a++;
  200. }
  201. output_data_temp[i] = sum;
  202. }
  203. }
  204. }
  205. }
  206. else
  207. {
  208. for (int i_batch = 0; i_batch < input_batches; i_batch++)
  209. {
  210. int32_t *output_data_temp = buffer_b + i_batch * unit_count;
  211. int32_t *ptr_a = buffer_a + i_batch * feature_batches;
  212. for (int i = 0; i < unit_count; i++)
  213. {
  214. int32_t sum = 0;
  215. for (int j = 0; j < rank; j++)
  216. {
  217. sum += *ptr_a;
  218. ptr_a++;
  219. }
  220. output_data_temp[i] = sum;
  221. }
  222. }
  223. }
  224. #if defined(ARM_MATH_MVEI)
  225. int32_t num_elements = input_batches * unit_count;
  226. const int32_t loop_count = (num_elements + 3) / 4;
  227. for (int i_op = 0; i_op < loop_count; i_op++)
  228. {
  229. mve_pred16_t p = vctp32q((uint32_t)num_elements);
  230. int32x4_t op = vldrwq_z_s32(buffer_b, p);
  231. op = arm_requantize_mve(op, multiplier_out, shift_2);
  232. op = vaddq_n_s32(op, zp_out);
  233. const int32x4_t min_vec = vdupq_n_s32((int8_t)out_activation_min);
  234. const int32x4_t max_vec = vdupq_n_s32((int8_t)out_activation_max);
  235. op = vmaxq_s32(op, min_vec);
  236. op = vminq_s32(op, max_vec);
  237. vstrbq_p_s32(output_data, op, p);
  238. output_data += 4;
  239. buffer_b += 4;
  240. num_elements -= 4;
  241. }
  242. #else
  243. for (int i = 0; i < input_batches * unit_count; i++)
  244. {
  245. output_data[i] = (int8_t)CLAMP(
  246. arm_nn_requantize(buffer_b[i], multiplier_out, shift_2) + zp_out, out_activation_max, out_activation_min);
  247. }
  248. #endif
  249. return (ARM_CMSIS_NN_SUCCESS);
  250. }
/**
 * @} end of SVDF group
 */