/*
 * Copyright (C) 2010-2021 Arm Limited or its affiliates.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_s8.c
 * Description:  General Matrix-multiplication function
 *
 * $Date:        27. October 2021
 * $Revision:    V.2.0.6
 *
 * Target Processor:  Cortex-M cores
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"
/*
 * s8 General matrix multiplication function with per-channel requantization for up to 4 column batches.
 *
 * Refer to the header file for details.
 *
 */
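/*
 * A minimal usage sketch (illustrative only; the buffer names and quantization
 * values below are assumptions, not taken from the CMSIS-NN headers).
 * input_row holds output_ch rows of row_len kernel values, input_col holds
 * col_batches columns of row_len input values:
 *
 *   q7_t *next_out = arm_nn_mat_mult_s8(kernel_rows,    // [output_ch * row_len]
 *                                        input_columns,  // [col_batches * row_len]
 *                                        output_ch,
 *                                        col_batches,    // at most 4
 *                                        per_ch_shift,   // [output_ch]
 *                                        per_ch_mult,    // [output_ch]
 *                                        out_offset,
 *                                        col_offset,
 *                                        0,              // row_offset, unused on MVE
 *                                        -128,           // activation_min
 *                                        127,            // activation_max
 *                                        row_len,
 *                                        bias,           // [output_ch] or NULL
 *                                        out);
 *
 *   // On MVE targets the returned pointer is out advanced past the written
 *   // batch; when ARM_MATH_MVEI is not defined the function returns NULL.
 */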
q7_t *arm_nn_mat_mult_s8(const q7_t *input_row,
                         const q7_t *input_col,
                         const uint16_t output_ch,
                         const uint16_t col_batches,
                         const int32_t *output_shift,
                         const int32_t *output_mult,
                         const int32_t out_offset,
                         const int32_t col_offset,
                         const int32_t row_offset,
                         const int16_t activation_min,
                         const int16_t activation_max,
                         const uint16_t row_len,
                         const int32_t *const bias,
                         q7_t *out)
{
#if defined(ARM_MATH_MVEI)
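    /* row_offset is not used by the MVE implementation; only the column offset is applied below. */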
    (void)row_offset;
    if (col_batches == 4)
    {
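        /* Full batch of four columns: compute four dot products per output channel in one pass. */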
        for (int i_out_ch = 0; i_out_ch < output_ch; i_out_ch++)
        {
            int32_t row_len_tmp = row_len;
            const int8_t *ip_r0 = input_row + (i_out_ch * row_len);
            const int8_t *ip_c0 = input_col;
            const int8_t *ip_c1 = input_col + row_len;
            const int8_t *ip_c2 = input_col + (2 * row_len);
            const int8_t *ip_c3 = input_col + (3 * row_len);
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;
            int32_t acc_2 = 0;
            int32_t acc_3 = 0;
            const int32_t row_loop_cnt = (row_len + 7) / 8;
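            /* Each iteration widens 8 s8 elements to s16; the tail predicate masks lanes beyond row_len. */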
            for (int i_row_loop = 0; i_row_loop < row_loop_cnt; i_row_loop++)
            {
                mve_pred16_t p = vctp16q((uint32_t)row_len_tmp);
                const int16x8_t offset = vdupq_m_n_s16(vuninitializedq_s16(), col_offset, p);
                row_len_tmp -= 8;
                int16x8_t c0 = vldrbq_s16(ip_c0);
                ip_c0 += 8;
                c0 = vaddq_s16(c0, offset);
                int16x8_t c1 = vldrbq_s16(ip_c1);
                ip_c1 += 8;
                c1 = vaddq_s16(c1, offset);
                int16x8_t c2 = vldrbq_s16(ip_c2);
                ip_c2 += 8;
                c2 = vaddq_s16(c2, offset);
                int16x8_t c3 = vldrbq_s16(ip_c3);
                ip_c3 += 8;
                c3 = vaddq_s16(c3, offset);
                int16x8_t r0 = vldrbq_z_s16(ip_r0, p);
                ip_r0 += 8;
                acc_0 = vmladavaq_p_s16(acc_0, r0, c0, p);
                acc_1 = vmladavaq_p_s16(acc_1, r0, c1, p);
                acc_2 = vmladavaq_p_s16(acc_2, r0, c2, p);
                acc_3 = vmladavaq_p_s16(acc_3, r0, c3, p);
            }
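            /* Gather the four accumulators, add the per-channel bias, requantize and clamp to the activation range. */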
            int32x4_t res = {acc_0, acc_1, acc_2, acc_3};
            if (bias)
            {
                res = vaddq_n_s32(res, bias[i_out_ch]);
            }
            res = arm_requantize_mve(res, output_mult[i_out_ch], output_shift[i_out_ch]);
            res = vaddq_n_s32(res, out_offset);
            res = vmaxq_s32(res, vdupq_n_s32(activation_min));
            res = vminq_s32(res, vdupq_n_s32(activation_max));
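            /* The four column results for this channel land output_ch bytes apart in the interleaved output. */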
            const uint32x4_t scatter_offset = {0, output_ch, output_ch * 2, output_ch * 3};
            vstrbq_scatter_offset_s32(&out[i_out_ch], scatter_offset, res);
        }
        out += 4 * output_ch;
    }
    else
    {
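        /* Fewer than four columns: for col_batches in 1..3 the masks below reduce the
         * loop bounds to [0, col_batches), so each remaining column is handled on its own. */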
        for (int i_col_batch = (col_batches & ~0x3); i_col_batch < (col_batches & 0x3); i_col_batch++)
        {
            for (int i_out_ch = 0; i_out_ch < output_ch; i_out_ch++)
            {
                int32_t row_len_tmp = row_len;
                const int8_t *ip_r0 = input_row + (i_out_ch * row_len);
                const int8_t *ip_c0 = input_col + (i_col_batch * row_len);
                int32_t acc_0 = 0;
                const int32_t row_loop_cnt = (row_len + 7) / 8;
                for (int i_row_loop = 0; i_row_loop < row_loop_cnt; i_row_loop++)
                {
                    const mve_pred16_t p = vctp16q((uint32_t)row_len_tmp);
                    const int16x8_t offset = vdupq_m_n_s16(vuninitializedq_s16(), col_offset, p);
                    row_len_tmp -= 8;
                    int16x8_t c0 = vldrbq_s16(ip_c0);
                    ip_c0 += 8;
                    c0 = vaddq_s16(c0, offset);
                    int16x8_t r0 = vldrbq_z_s16(ip_r0, p);
                    ip_r0 += 8;
                    acc_0 = vmladavaq_p_s16(acc_0, r0, c0, p);
                }
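                /* Scalar epilogue: bias, per-channel requantization, output offset and activation clamp. */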
                if (bias)
                {
                    acc_0 += bias[i_out_ch];
                }
                acc_0 = arm_nn_requantize(acc_0, output_mult[i_out_ch], output_shift[i_out_ch]);
                acc_0 += out_offset;
                acc_0 = MAX(acc_0, activation_min);
                acc_0 = MIN(acc_0, activation_max);
                out[i_out_ch] = (q7_t)acc_0;
            }
            out += output_ch;
        }
    }
    return out;
#else
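    /* This batched matrix multiplication is only implemented for MVE targets;
     * returning NULL signals that no output was produced. */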
    (void)input_row;
    (void)input_col;
    (void)output_ch;
    (void)col_batches;
    (void)output_shift;
    (void)output_mult;
    (void)out_offset;
    (void)col_offset;
    (void)row_offset;
    (void)activation_min;
    (void)activation_max;
    (void)row_len;
    (void)bias;
    (void)out;
    return NULL;
#endif
}