/*
 * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_depthwise_conv_s4_opt.c
 * Description:  Optimized s4 depthwise separable convolution function for
 *               channel multiplier of 1.
 *
 * $Date:        31 October 2023
 * $Revision:    V.1.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */
#include <string.h>

#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"
/**
 *  @ingroup Public
 */

/**
 * @addtogroup NNConv
 * @{
 */
  39. /*
  40. * Optimized s4 depthwise convolution function with constraint that in_channel equals out_channel
  41. *
  42. * Refer prototype header file for details.
  43. *
  44. */
  45. arm_cmsis_nn_status arm_depthwise_conv_s4_opt(const cmsis_nn_context *ctx,
  46. const cmsis_nn_dw_conv_params *dw_conv_params,
  47. const cmsis_nn_per_channel_quant_params *quant_params,
  48. const cmsis_nn_dims *input_dims,
  49. const int8_t *input,
  50. const cmsis_nn_dims *filter_dims,
  51. const int8_t *kernel,
  52. const cmsis_nn_dims *bias_dims,
  53. const int32_t *bias,
  54. const cmsis_nn_dims *output_dims,
  55. int8_t *output)
  56. {
  57. (void)bias_dims;
  58. const int32_t input_ch = input_dims->c;
  59. const int32_t output_ch = output_dims->c;
  60. /* Check depth multiplier is 1 */
  61. if (input_ch != output_ch)
  62. {
  63. return ARM_CMSIS_NN_ARG_ERROR;
  64. }
  65. if (ctx->buf == NULL)
  66. {
  67. return ARM_CMSIS_NN_ARG_ERROR;
  68. }
  69. const int32_t input_x = input_dims->w;
  70. const int32_t input_y = input_dims->h;
  71. const int32_t kernel_x = filter_dims->w;
  72. const int32_t kernel_y = filter_dims->h;
  73. const int32_t pad_x = dw_conv_params->padding.w;
  74. const int32_t pad_y = dw_conv_params->padding.h;
  75. const int32_t stride_x = dw_conv_params->stride.w;
  76. const int32_t stride_y = dw_conv_params->stride.h;
  77. const int32_t *output_shift = quant_params->shift;
  78. const int32_t *output_mult = quant_params->multiplier;
  79. const int32_t output_x = output_dims->w;
  80. const int32_t output_y = output_dims->h;
  81. const int32_t output_offset = dw_conv_params->output_offset;
  82. const int32_t input_offset = dw_conv_params->input_offset;
  83. const int32_t output_activation_min = dw_conv_params->activation.min;
  84. const int32_t output_activation_max = dw_conv_params->activation.max;
  85. int16_t *buffer_a = (int16_t *)ctx->buf;
  86. int16_t *const col_buffer_start = buffer_a;
  87. int16_t *col_buffer = col_buffer_start;
  88. const int32_t *const bias_start_pos = bias;
  89. const int32_t *const out_mult_start_pos = output_mult;
  90. const int32_t *const out_shift_start_pos = output_shift;
  91. const uint16_t num_cols = kernel_x * kernel_y;
  92. uint16_t row_count;
  93. uint16_t row_shift = 0;
  94. uint16_t col_shift = 0;
  95. for (int i_out_y = 0; i_out_y < output_y; i_out_y++)
  96. {
  97. const int16_t base_idx_y = (i_out_y * stride_y) - pad_y;
  98. for (int i_out_x = 0; i_out_x < output_x; i_out_x++)
  99. {
  100. const int16_t base_idx_x = (i_out_x * stride_x) - pad_x;
  101. /* Out of bounds is only considered for the y axis as it provides a contiguous zero'ing opportunity than
  102. along the x axis */
  103. const int ker_y_start = MAX(0, -base_idx_y);
  104. /* Condition for kernel end dimension: (base_idx_y + ker_y_end) < input_y */
  105. const int ker_y_end = MIN(kernel_y, input_y - base_idx_y);
  106. int32_t index = 0;
  107. if (ker_y_start != 0)
  108. {
  109. memset(&col_buffer[index], 0, (kernel_x * input_ch) * ker_y_start * sizeof(int16_t));
  110. index += (kernel_x * input_ch) * ker_y_start;
  111. }
  112. for (int i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
  113. {
  114. const int32_t idx_y = base_idx_y + i_ker_y;
  115. for (int i_ker_x = 0; i_ker_x < kernel_x; i_ker_x++)
  116. {
  117. const int32_t idx_x = base_idx_x + i_ker_x;
  118. if (idx_x < 0 || idx_x >= input_x)
  119. {
  120. memset(&col_buffer[index], 0, input_ch * sizeof(int16_t));
  121. }
  122. else
  123. {
  124. arm_q7_to_q15_with_offset((int8_t *)input + (idx_y * input_x + idx_x) * input_ch,
  125. &col_buffer[index],
  126. input_ch,
  127. (int16_t)input_offset);
  128. }
  129. index += input_ch;
  130. }
  131. }
  132. const int diff = kernel_y - ker_y_end;
  133. if (diff != 0)
  134. {
  135. memset(&col_buffer[index], 0, (kernel_x * input_ch) * diff * sizeof(int16_t));
  136. }
  137. row_count = output_ch / 4;
  138. row_shift = 0;
  139. col_shift = 0;
  140. bias = bias_start_pos;
  141. output_mult = out_mult_start_pos;
  142. output_shift = out_shift_start_pos;
  143. if (output_ch % 2) /* Uneven number of channels */
  144. {
  145. int get_low_nibble = 1;
  146. while (row_count)
  147. {
  148. int32_t sum = 0;
  149. int32_t sum_2 = 0;
  150. int32_t sum_3 = 0;
  151. int32_t sum_4 = 0;
  152. if (bias)
  153. {
  154. sum = *bias++;
  155. sum_2 = *bias++;
  156. sum_3 = *bias++;
  157. sum_4 = *bias++;
  158. }
  159. uint16_t col_count = num_cols / 2;
  160. int16_t *col_pos = col_buffer_start + col_shift;
  161. const int8_t *row_pos = kernel + row_shift;
  162. row_shift += 2;
  163. col_shift += 4;
  164. while (col_count)
  165. {
  166. #ifdef ARM_MATH_DSP
  167. /* General idea is to read 4 + 4 (input, kernel) pair and re-arrange them in the right order to
  168. use in a SMLAD instruction . One run of this loop produces 4 partial outputs with 8 MACs. */
  169. /* Note: variable names can be improved here to align with rows and columns. */
  170. int32_t ip_a1, ip_a2, ip_b1, ip_b2, op_a, op_b, op_c;
  171. /* Read 4 weights */
  172. read_and_pad_s4(row_pos, &ip_a2, &ip_b1);
  173. read_and_pad_s4_uneven(row_pos + (input_ch >> 1), &ip_a1, &ip_b2);
  174. op_a = arm_nn_read_s16x2(col_pos);
  175. op_b = arm_nn_read_s16x2(col_pos + input_ch);
  176. op_c = PKHBT(op_b, op_a, 16);
  177. op_a = PKHTB(op_b, op_a, 16);
  178. op_b = PKHBT(ip_b2, ip_a2, 16);
  179. sum = SMLAD(op_c, op_b, sum);
  180. op_b = PKHBT(ip_b1, ip_a1, 16);
  181. sum_2 = SMLAD(op_a, op_b, sum_2);
  182. op_a = arm_nn_read_s16x2(col_pos + 2);
  183. op_b = arm_nn_read_s16x2(col_pos + input_ch + 2);
  184. op_c = PKHBT(op_b, op_a, 16);
  185. op_a = PKHTB(op_b, op_a, 16);
  186. op_b = PKHTB(ip_a2, ip_b2, 16);
  187. sum_3 = SMLAD(op_c, op_b, sum_3);
  188. op_b = PKHTB(ip_a1, ip_b1, 16);
  189. sum_4 = SMLAD(op_a, op_b, sum_4);
  190. #else
  191. int8_t ker0, ker1, ker2, ker3, ker00, ker11;
  192. ker00 = row_pos[0];
  193. ker11 = row_pos[1];
  194. ker0 = (int8_t)(ker00 << 4) >> 4;
  195. ker1 = ker00 >> 4;
  196. ker2 = (int8_t)(ker11 << 4) >> 4;
  197. ker3 = ker11 >> 4;
  198. sum += ker0 * col_pos[0];
  199. sum_2 += ker1 * col_pos[1];
  200. sum_3 += ker2 * col_pos[2];
  201. sum_4 += ker3 * col_pos[3];
  202. ker11 = row_pos[1 + (input_ch >> 1)];
  203. ker0 = row_pos[0 + (input_ch >> 1)] >> 4;
  204. ker1 = (int8_t)(ker11 << 4) >> 4;
  205. ker2 = ker11 >> 4;
  206. ker3 = (int8_t)(row_pos[2 + (input_ch >> 1)] << 4) >> 4;
  207. sum += ker0 * col_pos[0 + input_ch];
  208. sum_2 += ker1 * col_pos[1 + input_ch];
  209. sum_3 += ker2 * col_pos[2 + input_ch];
  210. sum_4 += ker3 * col_pos[3 + input_ch];
  211. #endif
  212. row_pos += (input_ch);
  213. col_pos += input_ch << 1;
  214. col_count--;
  215. }
  216. col_count = num_cols & 0x1;
  217. while (col_count)
  218. {
  219. int8_t ker0, ker1, ker2, ker3, ker00, ker11;
  220. ker00 = row_pos[0];
  221. ker11 = row_pos[1];
  222. ker0 = (int8_t)(ker00 << 4) >> 4;
  223. ker1 = ker00 >> 4;
  224. ker2 = (int8_t)(ker11 << 4) >> 4;
  225. ker3 = ker11 >> 4;
  226. sum += ker0 * col_pos[0];
  227. sum_2 += ker1 * col_pos[1];
  228. sum_3 += ker2 * col_pos[2];
  229. sum_4 += ker3 * col_pos[3];
  230. row_pos += input_ch >> 1;
  231. col_pos += input_ch;
  232. col_count--;
  233. }
  234. sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
  235. sum += output_offset;
  236. sum = MAX(sum, output_activation_min);
  237. sum = MIN(sum, output_activation_max);
  238. *output++ = (int8_t)sum;
  239. sum_2 = arm_nn_requantize(sum_2, *output_mult++, *output_shift++);
  240. sum_2 += output_offset;
  241. sum_2 = MAX(sum_2, output_activation_min);
  242. sum_2 = MIN(sum_2, output_activation_max);
  243. *output++ = (int8_t)sum_2;
  244. sum_3 = arm_nn_requantize(sum_3, *output_mult++, *output_shift++);
  245. sum_3 += output_offset;
  246. sum_3 = MAX(sum_3, output_activation_min);
  247. sum_3 = MIN(sum_3, output_activation_max);
  248. *output++ = (int8_t)sum_3;
  249. sum_4 = arm_nn_requantize(sum_4, *output_mult++, *output_shift++);
  250. sum_4 += output_offset;
  251. sum_4 = MAX(sum_4, output_activation_min);
  252. sum_4 = MIN(sum_4, output_activation_max);
  253. *output++ = (int8_t)sum_4;
  254. row_count--;
  255. }
  256. row_count = output_ch & 0x3;
  257. while (row_count)
  258. {
  259. const int16_t *col_pos = col_buffer_start + col_shift;
  260. const int8_t *row_pos = kernel + row_shift;
  261. int32_t sum = 0;
  262. int col_index = 0;
  263. if (bias)
  264. {
  265. sum = *bias++;
  266. }
  267. col_shift += 1;
  268. for (int i = 0; i < num_cols; i++)
  269. {
  270. int8_t rhs = row_pos[i * (input_ch >> 1) + col_index];
  271. int8_t rhs0;
  272. int16_t lhs0 = col_pos[i * input_ch];
  273. if (get_low_nibble)
  274. {
  275. rhs0 = (int8_t)(rhs << 4) >> 4;
  276. get_low_nibble = 0;
  277. }
  278. else
  279. {
  280. rhs0 = rhs >> 4;
  281. get_low_nibble = 1;
  282. col_index++;
  283. }
  284. sum += rhs0 * lhs0;
  285. }
  286. if (num_cols % 2 == 0)
  287. {
  288. get_low_nibble = !get_low_nibble;
  289. }
  290. sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
  291. sum += output_offset;
  292. sum = MAX(sum, output_activation_min);
  293. sum = MIN(sum, output_activation_max);
  294. *output++ = (int8_t)sum;
  295. row_count--;
  296. /* Last row */
  297. if (row_count == 1)
  298. {
  299. row_shift += 1;
  300. }
  301. }
  302. }
  303. else /* Even number of channels */
  304. {
  305. while (row_count)
  306. {
  307. int32_t sum = 0;
  308. int32_t sum_2 = 0;
  309. int32_t sum_3 = 0;
  310. int32_t sum_4 = 0;
  311. if (bias)
  312. {
  313. sum = *bias++;
  314. sum_2 = *bias++;
  315. sum_3 = *bias++;
  316. sum_4 = *bias++;
  317. }
  318. uint16_t col_count = num_cols / 2;
  319. int16_t *col_pos = col_buffer_start + col_shift;
  320. const int8_t *row_pos = kernel + row_shift;
  321. row_shift += 2;
  322. col_shift += 4;
  323. #ifdef ARM_MATH_DSP
  324. while (col_count)
  325. {
  326. /* General idea is to read 4 + 4 (input, kernel) pair and re-arrange them in the right order to
  327. use in a SMLAD instruction . One run of this loop produces 4 partial outputs with 8 MACs. */
  328. /* Note: variable names can be improved here to align with rows and columns. */
  329. int32_t ip_a1, ip_a2, ip_b1, ip_b2, op_a, op_b, op_c;
  330. /* Read 4 weights */
  331. read_and_pad_s4(row_pos, &ip_a2, &ip_b1);
  332. read_and_pad_s4(row_pos + (input_ch >> 1), &ip_b2, &ip_a1);
  333. op_a = arm_nn_read_s16x2(col_pos);
  334. op_b = arm_nn_read_s16x2(col_pos + input_ch);
  335. op_c = PKHBT(op_b, op_a, 16);
  336. op_a = PKHTB(op_b, op_a, 16);
  337. op_b = PKHBT(ip_b2, ip_a2, 16);
  338. sum = SMLAD(op_c, op_b, sum);
  339. op_b = PKHBT(ip_b1, ip_a1, 16);
  340. sum_2 = SMLAD(op_a, op_b, sum_2);
  341. op_a = arm_nn_read_s16x2(col_pos + 2);
  342. op_b = arm_nn_read_s16x2(col_pos + input_ch + 2);
  343. op_c = PKHBT(op_b, op_a, 16);
  344. op_a = PKHTB(op_b, op_a, 16);
  345. op_b = PKHTB(ip_a2, ip_b2, 16);
  346. sum_3 = SMLAD(op_c, op_b, sum_3);
  347. op_b = PKHTB(ip_a1, ip_b1, 16);
  348. sum_4 = SMLAD(op_a, op_b, sum_4);
  349. row_pos += (input_ch);
  350. col_pos += input_ch << 1;
  351. col_count--;
  352. }
  353. col_count = num_cols & 0x1;
  354. #else
  355. col_count = num_cols;
  356. #endif
  357. while (col_count)
  358. {
  359. int8_t ker0, ker1, ker2, ker3, ker00, ker11;
  360. ker00 = row_pos[0];
  361. ker11 = row_pos[1];
  362. ker0 = (int8_t)(ker00 << 4) >> 4;
  363. ker1 = ker00 >> 4;
  364. ker2 = (int8_t)(ker11 << 4) >> 4;
  365. ker3 = ker11 >> 4;
  366. sum += ker0 * col_pos[0];
  367. sum_2 += ker1 * col_pos[1];
  368. sum_3 += ker2 * col_pos[2];
  369. sum_4 += ker3 * col_pos[3];
  370. row_pos += input_ch >> 1;
  371. col_pos += input_ch;
  372. col_count--;
  373. }
  374. sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
  375. sum += output_offset;
  376. sum = MAX(sum, output_activation_min);
  377. sum = MIN(sum, output_activation_max);
  378. *output++ = (int8_t)sum;
  379. sum_2 = arm_nn_requantize(sum_2, *output_mult++, *output_shift++);
  380. sum_2 += output_offset;
  381. sum_2 = MAX(sum_2, output_activation_min);
  382. sum_2 = MIN(sum_2, output_activation_max);
  383. *output++ = (int8_t)sum_2;
  384. sum_3 = arm_nn_requantize(sum_3, *output_mult++, *output_shift++);
  385. sum_3 += output_offset;
  386. sum_3 = MAX(sum_3, output_activation_min);
  387. sum_3 = MIN(sum_3, output_activation_max);
  388. *output++ = (int8_t)sum_3;
  389. sum_4 = arm_nn_requantize(sum_4, *output_mult++, *output_shift++);
  390. sum_4 += output_offset;
  391. sum_4 = MAX(sum_4, output_activation_min);
  392. sum_4 = MIN(sum_4, output_activation_max);
  393. *output++ = (int8_t)sum_4;
  394. row_count--;
  395. }
  396. if (output_ch & 0x2)
  397. {
  398. const int16_t *col_pos = col_buffer_start + col_shift;
  399. const int16_t *col_pos_2 = col_buffer_start + col_shift + 1;
  400. const int8_t *row_pos = kernel + row_shift;
  401. int32_t sum = 0;
  402. int32_t sum2 = 0;
  403. if (bias)
  404. {
  405. sum = *bias++;
  406. sum2 = *bias++;
  407. }
  408. for (int i = 0; i < num_cols; i++)
  409. {
  410. int8_t rhs = row_pos[i * (input_ch >> 1)];
  411. int8_t rhs_low = (int8_t)(rhs << 4) >> 4;
  412. int8_t rhs_high = rhs >> 4;
  413. int16_t lhs0 = col_pos[i * input_ch];
  414. int16_t lhs1 = col_pos_2[i * input_ch];
  415. sum += rhs_low * lhs0;
  416. sum2 += rhs_high * lhs1;
  417. }
  418. sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
  419. sum += output_offset;
  420. sum = MAX(sum, output_activation_min);
  421. sum = MIN(sum, output_activation_max);
  422. *output++ = (int8_t)sum;
  423. sum2 = arm_nn_requantize(sum2, *output_mult++, *output_shift++);
  424. sum2 += output_offset;
  425. sum2 = MAX(sum2, output_activation_min);
  426. sum2 = MIN(sum2, output_activation_max);
  427. *output++ = (int8_t)sum2;
  428. }
  429. }
  430. /* Clear counter and pointers */
  431. col_buffer = col_buffer_start;
  432. }
  433. }
  434. /* Return to application */
  435. return ARM_CMSIS_NN_SUCCESS;
  436. }
/**
 * @} end of NNConv group
 */