arm_avgpool_s8.c 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. /*
  2. * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_avgpool_s8.c
  21. * Description: Pooling function implementations
  22. *
  23. * $Date: March 4,2020
  24. * $Revision: V.1.1.1
  25. *
  26. * Target Processor: Cortex-M CPUs
  27. *
  28. * -------------------------------------------------------------------- */
  29. #include "arm_math.h"
  30. #include "arm_nnfunctions.h"
  31. /**
  32. * @ingroup groupNN
  33. */
  34. /**
  35. * @addtogroup Pooling
  36. * @{
  37. */
  38. /*
  39. * s8 average pooling function
  40. *
  41. * Refer to header file for details.
  42. *
  43. */
  44. #if defined(ARM_MATH_MVEI)
  45. arm_status arm_avgpool_s8(const int dim_src_height,
  46. const int dim_src_width,
  47. const int dim_dst_height,
  48. const int dim_dst_width,
  49. const int stride_height,
  50. const int stride_width,
  51. const int dim_kernel_height,
  52. const int dim_kernel_width,
  53. const int padding_height,
  54. const int padding_width,
  55. const int act_min,
  56. const int act_max,
  57. const int ch_src,
  58. int8_t *src,
  59. int16_t *bufferA,
  60. int8_t *dst)
  61. {
  62. (void)bufferA;
  63. int32_t i_x, i_y;
  64. int32_t k_x, k_y;
  65. for (i_y = 0; i_y < dim_dst_height; i_y++)
  66. {
  67. for (i_x = 0; i_x < dim_dst_width; i_x++)
  68. {
  69. int32_t k_y_start,k_y_end;
  70. int32_t k_x_start,k_x_end;
  71. int32_t chCnt;
  72. int8_t *pTmp,*pTmpInner;
  73. int8_t *pDst;
  74. k_y_start = MAX(0, i_y * stride_height - padding_height);
  75. k_y_end = MIN(i_y * stride_height - padding_height + dim_kernel_height,dim_src_height);
  76. k_x_start = MAX(0,i_x * stride_width - padding_width);
  77. k_x_end = MIN(i_x * stride_width - padding_width + dim_kernel_width, dim_src_width);
  78. pTmp = src;
  79. pDst = &dst[ch_src * (i_x + i_y * dim_dst_width)];
  80. chCnt = ch_src >> 4;
  81. while(chCnt > 0)
  82. {
  83. int32x4_t sumV1,sumV2,sumV3,sumV4;
  84. int8x16_t tempV;
  85. int16x8_t tempVLO, tempVHI;
  86. int32x4_t tempVLOLO, tempVLOHI, tempVHILO, tempVHIHI;
  87. int32_t count = 0;
  88. sumV1 = vdupq_n_s32(0);
  89. sumV2 = vdupq_n_s32(0);
  90. sumV3 = vdupq_n_s32(0);
  91. sumV4 = vdupq_n_s32(0);
  92. for (k_y = k_y_start; k_y < k_y_end; k_y++)
  93. {
  94. for (k_x = k_x_start; k_x < k_x_end; k_x++)
  95. {
  96. pTmpInner = pTmp + (ch_src * (k_x + k_y * dim_src_width));
  97. tempV = vldrbq_s8 (pTmpInner);
  98. tempVLO = vmovlbq_s8(tempV);
  99. tempVHI = vmovltq_s8(tempV);
  100. tempVLOLO = vmovlbq_s16(tempVLO);
  101. tempVLOHI = vmovltq_s16(tempVLO);
  102. tempVHILO = vmovlbq_s16(tempVHI);
  103. tempVHIHI = vmovltq_s16(tempVHI);
  104. sumV1 = vaddq_s32(sumV1,tempVLOLO);
  105. sumV2 = vaddq_s32(sumV2,tempVLOHI);
  106. sumV3 = vaddq_s32(sumV3,tempVHILO);
  107. sumV4 = vaddq_s32(sumV4,tempVHIHI);
  108. count++;
  109. }
  110. }
  111. sumV1[0] = sumV1[0] > 0 ? (sumV1[0] + count / 2) / count : (sumV1[0] - count / 2) / count;
  112. sumV1[1] = sumV1[1] > 0 ? (sumV1[1] + count / 2) / count : (sumV1[1] - count / 2) / count;
  113. sumV1[2] = sumV1[2] > 0 ? (sumV1[2] + count / 2) / count : (sumV1[2] - count / 2) / count;
  114. sumV1[3] = sumV1[3] > 0 ? (sumV1[3] + count / 2) / count : (sumV1[3] - count / 2) / count;
  115. sumV2[0] = sumV2[0] > 0 ? (sumV2[0] + count / 2) / count : (sumV2[0] - count / 2) / count;
  116. sumV2[1] = sumV2[1] > 0 ? (sumV2[1] + count / 2) / count : (sumV2[1] - count / 2) / count;
  117. sumV2[2] = sumV2[2] > 0 ? (sumV2[2] + count / 2) / count : (sumV2[2] - count / 2) / count;
  118. sumV2[3] = sumV2[3] > 0 ? (sumV2[3] + count / 2) / count : (sumV2[3] - count / 2) / count;
  119. sumV3[0] = sumV3[0] > 0 ? (sumV3[0] + count / 2) / count : (sumV3[0] - count / 2) / count;
  120. sumV3[1] = sumV3[1] > 0 ? (sumV3[1] + count / 2) / count : (sumV3[1] - count / 2) / count;
  121. sumV3[2] = sumV3[2] > 0 ? (sumV3[2] + count / 2) / count : (sumV3[2] - count / 2) / count;
  122. sumV3[3] = sumV3[3] > 0 ? (sumV3[3] + count / 2) / count : (sumV3[3] - count / 2) / count;
  123. sumV4[0] = sumV4[0] > 0 ? (sumV4[0] + count / 2) / count : (sumV4[0] - count / 2) / count;
  124. sumV4[1] = sumV4[1] > 0 ? (sumV4[1] + count / 2) / count : (sumV4[1] - count / 2) / count;
  125. sumV4[2] = sumV4[2] > 0 ? (sumV4[2] + count / 2) / count : (sumV4[2] - count / 2) / count;
  126. sumV4[3] = sumV4[3] > 0 ? (sumV4[3] + count / 2) / count : (sumV4[3] - count / 2) / count;
  127. sumV1 = vmaxq_s32(sumV1, vdupq_n_s32(act_min));
  128. sumV1 = vminq_s32(sumV1, vdupq_n_s32(act_max));
  129. sumV2 = vmaxq_s32(sumV2, vdupq_n_s32(act_min));
  130. sumV2 = vminq_s32(sumV2, vdupq_n_s32(act_max));
  131. sumV3 = vmaxq_s32(sumV3, vdupq_n_s32(act_min));
  132. sumV3 = vminq_s32(sumV3, vdupq_n_s32(act_max));
  133. sumV4 = vmaxq_s32(sumV4, vdupq_n_s32(act_min));
  134. sumV4 = vminq_s32(sumV4, vdupq_n_s32(act_max));
  135. tempVLO = vmovnbq_s32(tempVLO,sumV1);
  136. tempVLO = vmovntq_s32(tempVLO,sumV2);
  137. tempVHI = vmovnbq_s32(tempVHI,sumV3);
  138. tempVHI = vmovntq_s32(tempVHI,sumV4);
  139. tempV = vmovnbq_s16(tempV,tempVLO);
  140. tempV = vmovntq_s16(tempV,tempVHI);
  141. vstrbq_s8(pDst,tempV);
  142. pDst += 16;
  143. chCnt --;
  144. pTmp += 16;
  145. }
  146. chCnt = ch_src & 0xF;
  147. while(chCnt > 0)
  148. {
  149. int32_t sum = 0;
  150. int32_t count = 0;
  151. for (k_y = k_y_start; k_y < k_y_end; k_y++)
  152. {
  153. for (k_x = k_x_start; k_x < k_x_end; k_x++)
  154. {
  155. sum += pTmp[ch_src * (k_x + k_y * dim_src_width)];
  156. count++;
  157. }
  158. }
  159. sum = sum > 0 ? (sum + count / 2) / count : (sum - count / 2) / count;
  160. sum = MAX(sum, act_min);
  161. sum = MIN(sum, act_max);
  162. *pDst++ = sum;
  163. chCnt --;
  164. pTmp++;
  165. }
  166. }
  167. }
  168. return ARM_MATH_SUCCESS;
  169. }
  170. #else
  171. arm_status arm_avgpool_s8(const int dim_src_height,
  172. const int dim_src_width,
  173. const int dim_dst_height,
  174. const int dim_dst_width,
  175. const int stride_height,
  176. const int stride_width,
  177. const int dim_kernel_height,
  178. const int dim_kernel_width,
  179. const int padding_height,
  180. const int padding_width,
  181. const int act_min,
  182. const int act_max,
  183. const int ch_src,
  184. int8_t *src,
  185. int16_t *bufferA,
  186. int8_t *dst)
  187. {
  188. /* Reference C code adapted from CMSIS-NN arm_avepool_q7_HWC.
  189. */
  190. (void)bufferA;
  191. int16_t i_ch_in, i_x, i_y;
  192. int16_t k_x, k_y;
  193. for (i_y = 0; i_y < dim_dst_height; i_y++)
  194. {
  195. for (i_x = 0; i_x < dim_dst_width; i_x++)
  196. {
  197. for (i_ch_in = 0; i_ch_in < ch_src; i_ch_in++)
  198. {
  199. int sum = 0;
  200. int count = 0;
  201. for (k_y = i_y * stride_height - padding_height; k_y < i_y * stride_height - padding_height + dim_kernel_height; k_y++)
  202. {
  203. for (k_x = i_x * stride_width - padding_width; k_x < i_x * stride_width - padding_width + dim_kernel_width; k_x++)
  204. {
  205. if (k_y >= 0 && k_x >= 0 && k_y < dim_src_height && k_x < dim_src_width)
  206. {
  207. sum += src[i_ch_in + ch_src * (k_x + k_y * dim_src_width)];
  208. count++;
  209. }
  210. }
  211. }
  212. sum = sum > 0 ? (sum + count / 2) / count : (sum - count / 2) / count;
  213. sum = MAX(sum, act_min);
  214. sum = MIN(sum, act_max);
  215. dst[i_ch_in + ch_src * (i_x + i_y * dim_dst_width)] = sum;
  216. }
  217. }
  218. }
  219. return ARM_MATH_SUCCESS;
  220. }
  221. #endif /* ARM_MATH_MVEI */
  int32_t arm_avgpool_s8_get_buffer_size(const int dim_dst_width,
                                       const int ch_src)
{
    /*
     * Size of the scratch buffer (bufferA) that arm_avgpool_s8() expects.
     * Both arm_avgpool_s8() implementations in this file discard bufferA
     * unused, so the answer is always zero and the arguments are ignored.
     */
    (void)ch_src;
    (void)dim_dst_width;

    const int32_t required_size = 0;
    return required_size;
}
  229. /**
  230. * @} end of Pooling group
  231. */