arm_nn_mat_mult_s8.c

/*
 * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_s8.c
 * Description:  General Matrix-multiplication function
 *
 * $Date:        March 6, 2020
 * $Revision:    V.2.0.2
 *
 * Target Processor:  Cortex-M cores
 * -------------------------------------------------------------------- */

#include "arm_math.h"
#include "arm_nnfunctions.h"
/*
 * s8 General matrix multiplication function with per-channel requantization for up to 4 column batches.
 *
 * Refer to the header file for details.
 *
 */
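/*
 * A minimal usage sketch (not part of the original source). It assumes the
 * caller has gathered four input columns of length ROW_LEN into col_buffer
 * and holds OUT_CH filter rows in filter_rows; all names and parameter
 * values below are illustrative assumptions, not definitions from this file.
 *
 *   q7_t *out_ptr = output;
 *   out_ptr = arm_nn_mat_mult_s8(filter_rows, col_buffer, OUT_CH, 4,
 *                                per_ch_shift, per_ch_mult, out_offset,
 *                                col_offset, 0, act_min, act_max,
 *                                ROW_LEN, bias, out_ptr);
 *   // out_ptr now points past the 4 * OUT_CH values just written.
 */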
q7_t *arm_nn_mat_mult_s8(const q7_t *input_row,
                         const q7_t *input_col,
                         const uint16_t output_ch,
                         const uint16_t col_batches,
                         const int32_t *output_shift,
                         const int32_t *output_mult,
                         const int32_t out_offset,
                         const int32_t col_offset,
                         const int32_t row_offset,
                         const int16_t activation_min,
                         const int16_t activation_max,
                         const uint16_t row_len,
                         const int32_t *const bias,
                         q7_t *out)
{
#if defined(ARM_MATH_MVEI)
    (void)row_offset;
    if (col_batches == 4)
    {
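        /*
         * Fast path: each filter row (output channel) is multiplied against
         * four input columns at once; the four accumulators below share the
         * same filter data but read from four different column pointers.
         */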
        for (int i_out_ch = 0; i_out_ch < output_ch; i_out_ch++)
        {
            int32_t row_len_tmp = row_len;
            const int8_t *ip_r0 = input_row + (i_out_ch * row_len);
            const int8_t *ip_c0 = input_col;
            const int8_t *ip_c1 = input_col + row_len;
            const int8_t *ip_c2 = input_col + (2 * row_len);
            const int8_t *ip_c3 = input_col + (3 * row_len);

            int32_t acc_0 = bias[i_out_ch];
            int32_t acc_1 = bias[i_out_ch];
            int32_t acc_2 = bias[i_out_ch];
            int32_t acc_3 = bias[i_out_ch];
            const int32_t row_loop_cnt = (row_len + 7) / 8;
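            /*
             * Tail-predicated dot product: each iteration loads 8 s8 elements
             * widened to s16, vctp16q() masks off lanes beyond the remaining
             * length, col_offset is added to the column data, and
             * vmladavaq_p_s16() sums the active products into the accumulators.
             */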
            for (int i_row_loop = 0; i_row_loop < row_loop_cnt; i_row_loop++)
            {
                mve_pred16_t p = vctp16q(row_len_tmp);
                const int16x8_t offset = vdupq_m_n_s16(vuninitializedq_s16(), col_offset, p);
                row_len_tmp -= 8;

                int16x8_t r0 = vldrbq_z_s16(ip_r0, p);
                ip_r0 += 8;

                int16x8_t c0 = vldrbq_z_s16(ip_c0, p);
                ip_c0 += 8;
                c0 = vaddq_m_s16(vuninitializedq_s16(), c0, offset, p);

                int16x8_t c1 = vldrbq_z_s16(ip_c1, p);
                ip_c1 += 8;
                c1 = vaddq_m_s16(vuninitializedq_s16(), c1, offset, p);

                int16x8_t c2 = vldrbq_z_s16(ip_c2, p);
                ip_c2 += 8;
                c2 = vaddq_m_s16(vuninitializedq_s16(), c2, offset, p);

                int16x8_t c3 = vldrbq_z_s16(ip_c3, p);
                ip_c3 += 8;
                c3 = vaddq_m_s16(vuninitializedq_s16(), c3, offset, p);

                acc_0 = vmladavaq_p_s16(acc_0, r0, c0, p);
                acc_1 = vmladavaq_p_s16(acc_1, r0, c1, p);
                acc_2 = vmladavaq_p_s16(acc_2, r0, c2, p);
                acc_3 = vmladavaq_p_s16(acc_3, r0, c3, p);
            }
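            /*
             * Requantize the four accumulators with this channel's multiplier
             * and shift, add the output offset, clamp to the activation range,
             * and scatter-store one byte per column, output_ch bytes apart.
             */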
            int32x4_t res = {acc_0, acc_1, acc_2, acc_3};
            res = arm_requantize_mve(res, output_mult[i_out_ch], output_shift[i_out_ch]);
            res = vaddq_n_s32(res, out_offset);

            res = vmaxq_s32(res, vdupq_n_s32(activation_min));
            res = vminq_s32(res, vdupq_n_s32(activation_max));

            const uint32x4_t scatter_offset = {0, output_ch, output_ch * 2, output_ch * 3};
            vstrbq_scatter_offset_s32(&out[i_out_ch], scatter_offset, res);
        }
        out += 4 * output_ch;
    }
    else
    {
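        /*
         * Remaining columns (col_batches less than four) are handled one at a
         * time with the same predicated dot product, then requantized and
         * clamped with the scalar helpers.
         */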
        for (int i_col_batch = (col_batches & ~0x3); i_col_batch < (col_batches & 0x3); i_col_batch++)
        {
            for (int i_out_ch = 0; i_out_ch < output_ch; i_out_ch++)
            {
                int32_t row_len_tmp = row_len;

                const int8_t *ip_r0 = input_row + (i_out_ch * row_len);
                const int8_t *ip_c0 = input_col + (i_col_batch * row_len);
                int32_t acc_0 = bias[i_out_ch];
                const int32_t row_loop_cnt = (row_len + 7) / 8;

                for (int i_row_loop = 0; i_row_loop < row_loop_cnt; i_row_loop++)
                {
                    const mve_pred16_t p = vctp16q(row_len_tmp);
                    const int16x8_t offset = vdupq_m_n_s16(vuninitializedq_s16(), col_offset, p);
                    row_len_tmp -= 8;

                    int16x8_t r0 = vldrbq_z_s16(ip_r0, p);
                    ip_r0 += 8;

                    int16x8_t c0 = vldrbq_z_s16(ip_c0, p);
                    ip_c0 += 8;
                    c0 = vaddq_m_s16(vuninitializedq_s16(), c0, offset, p);

                    acc_0 = vmladavaq_p_s16(acc_0, r0, c0, p);
                }

                acc_0 = arm_nn_requantize(acc_0, output_mult[i_out_ch], output_shift[i_out_ch]);
                acc_0 += out_offset;
                acc_0 = MAX(acc_0, activation_min);
                acc_0 = MIN(acc_0, activation_max);
                out[i_out_ch] = (q7_t)acc_0;
            }
            out += output_ch;
        }
    }
    return out;
#else
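    /* This kernel is only implemented for MVE (Helium) targets; without
     * ARM_MATH_MVEI the arguments are unused and NULL is returned. */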
    (void)input_row;
    (void)input_col;
    (void)output_ch;
    (void)col_batches;
    (void)output_shift;
    (void)output_mult;
    (void)out_offset;
    (void)col_offset;
    (void)row_offset;
    (void)activation_min;
    (void)activation_max;
    (void)row_len;
    (void)bias;
    (void)out;
    return NULL;
#endif
}