arm_fully_connected_s8.c 8.6 KB


  1. /*
  2. * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_fully_connected_s8
  21. * Description: Fully connected function compatible with TF Lite.
  22. *
  23. * $Date: April 1, 2020
  24. * $Revision: V.1.5.0
  25. *
  26. * Target Processor: Cortex-M and Cortex-A cores
  27. *
  28. * -------------------------------------------------------------------- */
  29. #include "arm_math.h"
  30. #include "arm_nnfunctions.h"
  31. /**
  32. * @ingroup groupNN
  33. */
  34. /**
  35. * @addtogroup FC
  36. * @{
  37. */
  38. /*
  39. * S8 basic fully-connected and matrix multiplication layer function for TensorFlow Lite
  40. *
* Refer to the header file for details.
  42. *
  43. */
  44. #if defined(ARM_MATH_MVEI)
  45. arm_status
  46. arm_fully_connected_s8(const int8_t *input,
  47. const int8_t *kernel,
  48. const uint16_t col_dim,
  49. const uint16_t row_dim,
  50. const uint16_t nb_batches,
  51. const int32_t input_offset,
  52. const int32_t filter_offset,
  53. const int32_t out_mult,
  54. const int32_t out_shift,
  55. const int32_t output_offset,
  56. const int32_t *bias,
  57. int8_t *output,
  58. const int32_t output_activation_min,
  59. const int32_t output_activation_max,
  60. q15_t *vec_buffer)
  61. {
  62. (void)vec_buffer;
  63. const int8_t *input_a;
  64. const int32_t *bias_tmp = bias;
  65. const int8_t *weight_tmp = kernel;
  66. int32_t batch_count = nb_batches;
  67. const int16x8_t filter_offset_vec = vdupq_n_s16((int16_t)filter_offset);
  68. const int16x8_t input_offset_vec = vdupq_n_s16((int16_t)input_offset);
  69. while (batch_count)
  70. {
  71. bias_tmp = bias;
  72. weight_tmp = kernel;
  73. int cnt;
  74. cnt = row_dim >> 2;
  75. for (int out_c = 0; out_c < cnt; out_c++)
  76. {
  77. int32_t acc1 = *bias_tmp++;
  78. int32_t acc2 = *bias_tmp++;
  79. int32_t acc3 = *bias_tmp++;
  80. int32_t acc4 = *bias_tmp++;
  81. input_a = input;
  82. int16x8_t input_val, filter_val;
  83. int16x8_t tmp_a1, tmp_a2, tmp_a3, tmp_a4, tmp_b;
  84. int32x4_t acc;
  85. int32_t block_count;
  86. const int8_t *col = input_a;
  87. const int8_t *row_0 = weight_tmp;
  88. const int8_t *row_1 = weight_tmp + col_dim;
  89. const int8_t *row_2 = weight_tmp + 2 * col_dim;
  90. const int8_t *row_3 = weight_tmp + 3 * col_dim;
  91. block_count = col_dim >> 3U;
  92. while (block_count > 0U)
  93. {
  94. input_val = vldrbq_s16(col);
  95. tmp_b = vaddq_s16(input_val, input_offset_vec);
  96. filter_val = vldrbq_s16(row_0);
  97. tmp_a1 = vaddq_s16(filter_val, filter_offset_vec);
  98. acc1 = vmladavaq_s16(acc1, tmp_a1, tmp_b);
  99. filter_val = vldrbq_s16(row_1);
  100. tmp_a2 = vaddq_s16(filter_val, filter_offset_vec);
  101. acc2 = vmladavaq_s16(acc2, tmp_a2, tmp_b);
  102. filter_val = vldrbq_s16(row_2);
  103. tmp_a3 = vaddq_s16(filter_val, filter_offset_vec);
  104. acc3 = vmladavaq_s16(acc3, tmp_a3, tmp_b);
  105. filter_val = vldrbq_s16(row_3);
  106. tmp_a4 = vaddq_s16(filter_val, filter_offset_vec);
  107. acc4 = vmladavaq_s16(acc4, tmp_a4, tmp_b);
  108. col += 8;
  109. row_0 += 8;
  110. row_1 += 8;
  111. row_2 += 8;
  112. row_3 += 8;
  113. block_count--;
  114. }
  115. block_count = col_dim & 7;
  116. while (block_count > 0U)
  117. {
  118. q15_t col_ip = *col++;
  119. q7_t in_m1 = *row_0++;
  120. q7_t in_m2 = *row_1++;
  121. q7_t in_m3 = *row_2++;
  122. q7_t in_m4 = *row_3++;
  123. acc1 += (col_ip + input_offset) * (in_m1 + filter_offset);
  124. acc2 += (col_ip + input_offset) * (in_m2 + filter_offset);
  125. acc3 += (col_ip + input_offset) * (in_m3 + filter_offset);
  126. acc4 += (col_ip + input_offset) * (in_m4 + filter_offset);
  127. block_count--;
  128. }
  129. input_a = input + col_dim;
  130. weight_tmp += 4 * col_dim;
  131. acc[0] = acc1;
  132. acc[1] = acc2;
  133. acc[2] = acc3;
  134. acc[3] = acc4;
  135. acc = arm_requantize_mve(acc, out_mult, out_shift);
  136. acc = vaddq_s32(acc, vdupq_n_s32(output_offset));
  137. acc = vmaxq_s32(acc, vdupq_n_s32(output_activation_min));
  138. acc = vminq_s32(acc, vdupq_n_s32(output_activation_max));
  139. vstrbq_s32(output, acc);
  140. output += 4;
  141. }
  142. cnt = row_dim & 3;
  143. for (int out_c = 0; out_c < cnt; out_c++)
  144. {
  145. int32_t acc = *bias_tmp++;
  146. input_a = input;
  147. int16x8_t input_val, filter_val;
  148. int16x8_t tmp_a, tmp_b;
  149. int32_t block_count;
  150. const int8_t *col = input_a;
  151. const int8_t *kernel_cur = weight_tmp;
  152. block_count = col_dim >> 3U;
  153. while (block_count > 0U)
  154. {
  155. input_val = vldrbq_s16(col);
  156. filter_val = vldrbq_s16(kernel_cur);
  157. tmp_a = vaddq_s16(filter_val, filter_offset_vec);
  158. tmp_b = vaddq_s16(input_val, input_offset_vec);
  159. acc = vmladavaq_s16(acc, tmp_a, tmp_b);
  160. col += 8;
  161. kernel_cur += 8;
  162. block_count--;
  163. }
  164. block_count = col_dim & 7;
  165. while (block_count > 0U)
  166. {
  167. q15_t col_ip = *col++;
  168. q7_t in_m = *kernel_cur++;
  169. acc += (col_ip + input_offset) * (in_m + filter_offset);
  170. block_count--;
  171. }
  172. input_a += col_dim;
  173. weight_tmp += col_dim;
  174. acc = arm_nn_sat_doubling_high_mult(acc * (1 << LEFT_SHIFT(out_shift)), out_mult);
  175. acc = arm_nn_divide_by_power_of_two(acc, RIGHT_SHIFT(out_shift));
  176. acc += output_offset;
  177. acc = MAX(acc, output_activation_min);
  178. acc = MIN(acc, output_activation_max);
  179. *output++ = (int8_t)(acc);
  180. }
  181. input += col_dim;
  182. batch_count--;
  183. }
  184. return (ARM_MATH_SUCCESS);
  185. }
  186. #else
  187. arm_status
  188. arm_fully_connected_s8(const int8_t *input,
  189. const int8_t *kernel,
  190. const uint16_t col_dim,
  191. const uint16_t row_dim,
  192. const uint16_t nb_batches,
  193. const int32_t input_offset,
  194. const int32_t filter_offset,
  195. const int32_t out_mult,
  196. const int32_t out_shift,
  197. const int32_t output_offset,
  198. const int32_t *bias,
  199. int8_t *output,
  200. const int32_t output_activation_min,
  201. const int32_t output_activation_max,
  202. q15_t *vec_buffer)
  203. {
  204. (void)vec_buffer;
  205. uint16_t batch_cnt = nb_batches;
  206. while (batch_cnt)
  207. {
  208. arm_nn_vec_mat_mult_t_s8(input,
  209. kernel,
  210. bias,
  211. output,
  212. input_offset,
  213. filter_offset,
  214. output_offset,
  215. out_mult,
  216. out_shift,
  217. col_dim,
  218. row_dim,
  219. output_activation_min,
  220. output_activation_max);
  221. input += col_dim;
  222. output += row_dim;
  223. batch_cnt--;
  224. }
  225. return (ARM_MATH_SUCCESS);
  226. }
#endif /* ARM_MATH_MVEI */
  228. int32_t arm_fully_connected_s8_get_buffer_size(const uint16_t col_dim)
  229. {
  230. (void)col_dim;
  231. return 0;
  232. }
  233. /**
  234. * @} end of FC group
  235. */