arm_convolve_get_buffer_sizes_s8.c 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. /*
  2. * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_convolve_get_buffer_sizes_s8.c
  21. * Description: Collection of get buffer size functions for the various s8 convolution layer functions.
  22. *
  23. * $Date: 27 February 2024
  24. * $Revision: V.2.0.1
  25. *
  26. * Target : Arm(R) M-Profile Architecture
  27. *
  28. * -------------------------------------------------------------------- */
  29. #include "Internal/arm_nn_compiler.h"
  30. #include "arm_nnfunctions.h"
  31. #include "arm_nnsupportfunctions.h"
  32. /**
  33. * @ingroup NNConv
  34. */
  35. /**
  36. * @addtogroup GetBufferSizeNNConv
  37. * @{
  38. */
  39. __STATIC_INLINE int32_t arm_convolve_s8_get_buffer_size_mve(const cmsis_nn_dims *input_dims,
  40. const cmsis_nn_dims *filter_dims)
  41. {
  42. int32_t col_length = input_dims->c * filter_dims->w * filter_dims->h;
  43. // Get number of complete int16 lanes(multiple of 8) for given col_length. This is dependent on
  44. // implementation of arm_nn_mat_mult_nt_t_s8
  45. col_length = (col_length + 7) / 8;
  46. // 4 -> number of im2col buffers, 8 -> 8 elements per Q register
  47. return 4 * col_length * 8 * (int32_t)sizeof(int8_t);
  48. }
  49. __STATIC_INLINE int32_t arm_convolve_1_x_n_s8_get_buffer_size_mve(const cmsis_nn_conv_params *conv_params,
  50. const cmsis_nn_dims *input_dims,
  51. const cmsis_nn_dims *filter_dims,
  52. const cmsis_nn_dims *output_dims)
  53. {
  54. const int32_t input_x = input_dims->w;
  55. const int32_t pad_x = conv_params->padding.w;
  56. const int32_t kernel_x = filter_dims->w;
  57. const int32_t output_x = output_dims->w;
  58. const int32_t stride_x = conv_params->stride.w;
  59. const int32_t total_pad = ((output_x - 1) * stride_x + kernel_x - input_x);
  60. const int32_t asym_pad = total_pad % 2;
  61. const int32_t right_pad_num = pad_x + asym_pad != 0 ? MAX(1, (pad_x + asym_pad + stride_x - 1) / stride_x) : 0;
  62. const int32_t left_pad_num = pad_x != 0 ? MAX(1, (pad_x + stride_x - 1) / stride_x) : 0;
  63. const int32_t no_pad_num = MAX(output_x - (right_pad_num + left_pad_num), 0);
  64. if (right_pad_num + no_pad_num + left_pad_num != output_x)
  65. {
  66. return arm_convolve_s8_get_buffer_size_mve(input_dims, filter_dims);
  67. }
  68. return 0;
  69. }
  70. int32_t arm_convolve_s8_get_buffer_size(const cmsis_nn_dims *input_dims, const cmsis_nn_dims *filter_dims)
  71. {
  72. #if defined(ARM_MATH_MVEI)
  73. return arm_convolve_s8_get_buffer_size_mve(input_dims, filter_dims);
  74. #else
  75. const int32_t rhs_cols = filter_dims->w * filter_dims->h * input_dims->c;
  76. const int32_t remainder = rhs_cols % 4;
  77. const int32_t aligned_rhs_cols = remainder != 0 ? rhs_cols + 4 - remainder : rhs_cols;
  78. return (2 * aligned_rhs_cols) * (int32_t)sizeof(int16_t);
  79. #endif
  80. }
  81. int32_t arm_convolve_1_x_n_s8_get_buffer_size(const cmsis_nn_conv_params *conv_params,
  82. const cmsis_nn_dims *input_dims,
  83. const cmsis_nn_dims *filter_dims,
  84. const cmsis_nn_dims *output_dims)
  85. {
  86. #if !defined(ARM_MATH_MVEI)
  87. (void)conv_params;
  88. (void)output_dims;
  89. return arm_convolve_s8_get_buffer_size(input_dims, filter_dims);
  90. #else
  91. return arm_convolve_1_x_n_s8_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
  92. #endif
  93. }
  94. int32_t arm_convolve_1x1_s8_fast_get_buffer_size(const cmsis_nn_dims *input_dims)
  95. {
  96. (void)input_dims;
  97. return 0;
  98. }
  99. /*
  100. * Get the required buffer size for arm_convolve_wrapper_s8. This is the recommended function convolve wrapper s8
  101. * function.
  102. *
  103. * Refer to header file for details.
  104. *
  105. */
  106. int32_t arm_convolve_wrapper_s8_get_buffer_size(const cmsis_nn_conv_params *conv_params,
  107. const cmsis_nn_dims *input_dims,
  108. const cmsis_nn_dims *filter_dims,
  109. const cmsis_nn_dims *output_dims)
  110. {
  111. #if defined(ARM_MATH_MVEI)
  112. return arm_convolve_wrapper_s8_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
  113. #else
  114. (void)output_dims;
  115. if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
  116. (filter_dims->h == 1) && (conv_params->dilation.w == 1 && conv_params->dilation.h == 1))
  117. {
  118. if ((conv_params->stride.w == 1) && (conv_params->stride.h == 1))
  119. {
  120. return arm_convolve_1x1_s8_fast_get_buffer_size(input_dims);
  121. }
  122. else
  123. {
  124. return 0;
  125. }
  126. }
  127. else if ((input_dims->h == 1) && (conv_params->dilation.w == 1) && (filter_dims->h == 1) &&
  128. (conv_params->stride.w * input_dims->c % 4 == 0))
  129. {
  130. return arm_convolve_1_x_n_s8_get_buffer_size(conv_params, input_dims, filter_dims, output_dims);
  131. }
  132. else
  133. {
  134. return arm_convolve_s8_get_buffer_size(input_dims, filter_dims);
  135. }
  136. #endif
  137. }
  138. int32_t arm_convolve_wrapper_s8_get_buffer_size_mve(const cmsis_nn_conv_params *conv_params,
  139. const cmsis_nn_dims *input_dims,
  140. const cmsis_nn_dims *filter_dims,
  141. const cmsis_nn_dims *output_dims)
  142. {
  143. (void)output_dims;
  144. if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
  145. (filter_dims->h == 1) && (conv_params->dilation.w == 1 && conv_params->dilation.h == 1))
  146. {
  147. if ((conv_params->stride.w == 1) && (conv_params->stride.h == 1))
  148. {
  149. return arm_convolve_1x1_s8_fast_get_buffer_size(input_dims);
  150. }
  151. else
  152. {
  153. return 0;
  154. }
  155. }
  156. else if ((input_dims->h == 1) && (conv_params->dilation.w == 1) && (filter_dims->h == 1) &&
  157. (conv_params->stride.w * input_dims->c % 4 == 0))
  158. {
  159. return arm_convolve_1_x_n_s8_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
  160. }
  161. else
  162. {
  163. return arm_convolve_s8_get_buffer_size_mve(input_dims, filter_dims);
  164. }
  165. }
  166. int32_t arm_convolve_wrapper_s8_get_buffer_size_dsp(const cmsis_nn_conv_params *conv_params,
  167. const cmsis_nn_dims *input_dims,
  168. const cmsis_nn_dims *filter_dims,
  169. const cmsis_nn_dims *output_dims)
  170. {
  171. return arm_convolve_wrapper_s8_get_buffer_size(conv_params, input_dims, filter_dims, output_dims);
  172. }
  173. /**
  174. * @} end of GetBufferSizeNNConv group
  175. */