arm_vector_sum_s8.c 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. /*
  2. * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_vector_sum_s8
  21. * Description: Generic function for calculating vector sums
  22. *
  23. * $Date: 15 February 2024
  24. * $Revision: V.2.0.1
  25. *
  26. * Target : Arm(R) M-Profile Architecture
  27. *
  28. * -------------------------------------------------------------------- */
  29. #include "arm_nnfunctions.h"
  30. #include "arm_nnsupportfunctions.h"
  31. /**
  32. * @ingroup Public
  33. */
  34. /**
  35. * @addtogroup FC
  36. * @{
  37. */
  38. /*
  39. * S8 vector sum fuction in preparation for e.g. kernel sums in fully connected and matrix multiplication layer function
  40. *
  41. * Refer header file for details.
  42. *
  43. */
  44. arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
  45. const int32_t vector_cols,
  46. const int32_t vector_rows,
  47. const int8_t *vector_data,
  48. const int32_t lhs_offset,
  49. const int32_t *bias_data)
  50. {
  51. if (bias_data)
  52. {
  53. memcpy(vector_sum_buf, bias_data, vector_rows * sizeof(int32_t));
  54. }
  55. else
  56. {
  57. memset(vector_sum_buf, 0, vector_rows * sizeof(int32_t));
  58. }
  59. if (lhs_offset)
  60. {
  61. #if defined(ARM_MATH_MVEI)
  62. const int32_t row_loop_cnt = vector_rows / 5;
  63. for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
  64. {
  65. const int32_t col_loop_cnt = (vector_cols + 15) / 16;
  66. const int8_t *vector_0 = vector_data;
  67. const int8_t *vector_1 = vector_data + vector_cols;
  68. const int8_t *vector_2 = vector_data + 2 * vector_cols;
  69. const int8_t *vector_3 = vector_data + 3 * vector_cols;
  70. const int8_t *vector_4 = vector_data + 4 * vector_cols;
  71. int32_t vector_sum_0 = 0;
  72. int32_t vector_sum_1 = 0;
  73. int32_t vector_sum_2 = 0;
  74. int32_t vector_sum_3 = 0;
  75. int32_t vector_sum_4 = 0;
  76. uint32_t col_cnt = (uint32_t)vector_cols;
  77. for (int i = 0; i < col_loop_cnt; i++)
  78. {
  79. mve_pred16_t p = vctp8q(col_cnt);
  80. col_cnt -= 16;
  81. const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p);
  82. vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0);
  83. const int8x16_t ker_1 = vldrbq_z_s8(vector_1, p);
  84. vector_sum_1 = vaddvaq_s8(vector_sum_1, ker_1);
  85. const int8x16_t ker_2 = vldrbq_z_s8(vector_2, p);
  86. vector_sum_2 = vaddvaq_s8(vector_sum_2, ker_2);
  87. const int8x16_t ker_3 = vldrbq_z_s8(vector_3, p);
  88. vector_sum_3 = vaddvaq_s8(vector_sum_3, ker_3);
  89. const int8x16_t ker_4 = vldrbq_z_s8(vector_4, p);
  90. vector_sum_4 = vaddvaq_s8(vector_sum_4, ker_4);
  91. vector_0 += 16;
  92. vector_1 += 16;
  93. vector_2 += 16;
  94. vector_3 += 16;
  95. vector_4 += 16;
  96. }
  97. vector_data += 5 * vector_cols;
  98. vector_sum_0 *= lhs_offset;
  99. vector_sum_1 *= lhs_offset;
  100. vector_sum_2 *= lhs_offset;
  101. vector_sum_3 *= lhs_offset;
  102. vector_sum_4 *= lhs_offset;
  103. vector_sum_buf[0] += vector_sum_0;
  104. vector_sum_buf[1] += vector_sum_1;
  105. vector_sum_buf[2] += vector_sum_2;
  106. vector_sum_buf[3] += vector_sum_3;
  107. vector_sum_buf[4] += vector_sum_4;
  108. vector_sum_buf += 5;
  109. }
  110. const int32_t loop_cnt = vector_rows % 5;
  111. for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
  112. {
  113. const int32_t col_loop_cnt = (vector_cols + 15) / 16;
  114. const int8_t *vector_0 = vector_data;
  115. int32_t vector_sum_0 = 0;
  116. uint32_t col_cnt = (uint32_t)vector_cols;
  117. for (int i = 0; i < col_loop_cnt; i++)
  118. {
  119. mve_pred16_t p = vctp8q(col_cnt);
  120. col_cnt -= 16;
  121. const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p);
  122. vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0);
  123. vector_0 += 16;
  124. }
  125. vector_data += vector_cols;
  126. vector_sum_0 *= lhs_offset;
  127. vector_sum_buf[i_row_loop_cnt] += vector_sum_0;
  128. }
  129. #else
  130. for (int i = 0; i < vector_rows; i++)
  131. {
  132. int32_t sum = 0;
  133. for (int j = 0; j < vector_cols; j++)
  134. {
  135. sum += *vector_data++;
  136. }
  137. *vector_sum_buf++ += sum * lhs_offset;
  138. }
  139. #endif
  140. }
  141. return (ARM_CMSIS_NN_SUCCESS);
  142. }
  143. /**
  144. * @} end of FC group
  145. */