arm_depthwise_conv_s4.c 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
/*
 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_depthwise_conv_s4.c
 * Description:  s4 version of depthwise convolution.
 *
 * $Date:        13 February 2024
 * $Revision:    V.1.1.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */
  29. #include "arm_nnfunctions.h"
  30. #include "arm_nnsupportfunctions.h"
/**
 * @ingroup Public
 */

/**
 * @addtogroup NNConv
 * @{
 */
  /*
 * Sign-extend the low nibble (packed element 2k) of one s4 kernel byte.
 * The bit pattern is handled as unsigned so no negative value is ever
 * left-shifted: (n ^ 8) - 8 maps 0..7 to 0..7 and 8..15 to -8..-1.
 */
static inline int32_t depthwise_conv_s4_low_nibble(const int8_t packed)
{
    return (int32_t)((((uint32_t)(uint8_t)packed) & 0x0Fu) ^ 0x08u) - 0x08;
}

/* Sign-extend the high nibble (packed element 2k + 1) of one s4 kernel byte. */
static inline int32_t depthwise_conv_s4_high_nibble(const int8_t packed)
{
    return (int32_t)((((uint32_t)(uint8_t)packed) >> 4) ^ 0x08u) - 0x08;
}

/* Fetch s4 weight number `elem` (counted from 0) from a packed kernel buffer. */
static inline int32_t depthwise_conv_s4_weight(const int8_t *kernel, const int32_t elem)
{
    const int8_t packed = kernel[elem >> 1];
    return (elem & 1) ? depthwise_conv_s4_high_nibble(packed) : depthwise_conv_s4_low_nibble(packed);
}

/*
 * Compute the half-open range [*start, *end) of kernel taps along one axis whose
 * input coordinate base_idx + dilation * tap lies inside [0, input_size).
 * An empty range (*end <= *start) means no tap along this axis is in bounds.
 */
static void depthwise_conv_s4_tap_range(const int32_t base_idx,
                                        const int32_t input_size,
                                        const int32_t kernel_size,
                                        const int32_t dilation,
                                        int32_t *start,
                                        int32_t *end)
{
    if (dilation > 1)
    {
        /* (a + dilation - 1) / dilation rounds a non-negative a up; a negative a
           (tap 0 already in bounds, or no tap in bounds) is handled by the clamps. */
        const int32_t first_tap = (-base_idx + dilation - 1) / dilation;
        const int32_t end_tap = (input_size - base_idx + dilation - 1) / dilation;
        *start = MAX(0, first_tap);
        *end = MIN(kernel_size, end_tap);
    }
    else
    {
        *start = MAX(0, -base_idx);
        *end = MIN(kernel_size, input_size - base_idx);
    }
}

/* Requantize one accumulator, add the output offset and clamp it to the activation range. */
static inline int8_t depthwise_conv_s4_finalize(int32_t acc,
                                                const int32_t mult,
                                                const int32_t shift,
                                                const int32_t output_offset,
                                                const int32_t activation_min,
                                                const int32_t activation_max)
{
    acc = arm_nn_requantize(acc, mult, shift);
    acc += output_offset;
    acc = MAX(acc, activation_min);
    acc = MIN(acc, activation_max);
    return (int8_t)acc;
}

/*
 * Plain C depthwise convolution with packed signed 4-bit (s4) weights.
 *
 * Data layout, as used by the indexing below:
 *  - input:  input_batches blocks of input_y x input_x x input_ch int8 values (channel
 *            fastest); input_offset is added to every input value before multiplying.
 *  - kernel: kernel_y x kernel_x x (input_ch * ch_mult) s4 weights packed two per byte,
 *            element 2k in the low nibble of byte k and element 2k + 1 in the high nibble.
 *  - output: written sequentially, output_y x output_x x (input_ch * ch_mult) per batch;
 *            output channel (c * ch_mult + m) is computed from input channel c.
 *  - bias (may be NULL), output_mult and output_shift: one entry per output channel.
 * Kernel taps that fall outside the input (padding / borders) are skipped.
 * output_ch is not read; the channel count is derived from input_ch * ch_mult.
 */
static void depthwise_conv_s4_generic(const int8_t *input,
                                      const int32_t input_batches,
                                      const int32_t input_x,
                                      const int32_t input_y,
                                      const int32_t input_ch,
                                      const int8_t *kernel,
                                      const int32_t output_ch,
                                      const int32_t ch_mult,
                                      const int32_t kernel_x,
                                      const int32_t kernel_y,
                                      const int32_t pad_x,
                                      const int32_t pad_y,
                                      const int32_t stride_x,
                                      const int32_t stride_y,
                                      const int32_t *bias,
                                      int8_t *output,
                                      const int32_t *output_shift,
                                      const int32_t *output_mult,
                                      const int32_t output_x,
                                      const int32_t output_y,
                                      const int32_t output_offset,
                                      const int32_t input_offset,
                                      const int32_t output_activation_min,
                                      const int32_t output_activation_max,
                                      const int32_t dilation_x,
                                      const int32_t dilation_y)
{
    (void)output_ch;

    /* Number of channels in the packed kernel (and in the output). */
    const int32_t kernel_ch = input_ch * ch_mult;
    /* Bytes between two consecutive kernel taps; exact only when kernel_ch is even,
       which holds for both paired paths below. */
    const int32_t kernel_tap_bytes = kernel_ch >> 1;
    /* Input elements between two horizontally consecutive (dilated) taps. */
    const int32_t input_tap_step = dilation_x * input_ch;
    /* Two outputs can share one kernel byte only when their channel pair starts on
       an even element and kernel_ch is even, so every tap stays byte-aligned. */
    const int pair_input_ch = (ch_mult == 1) && ((input_ch % 2) == 0);
    const int pair_ch_mult = (ch_mult % 2) == 0;
    int32_t i_out = 0;

    for (int32_t i_batch = 0; i_batch < input_batches; i_batch++)
    {
        for (int32_t i_out_y = 0; i_out_y < output_y; i_out_y++)
        {
            const int32_t base_idx_y = (i_out_y * stride_y) - pad_y;
            int32_t ker_y_start;
            int32_t ker_y_end;
            depthwise_conv_s4_tap_range(base_idx_y, input_y, kernel_y, dilation_y, &ker_y_start, &ker_y_end);

            for (int32_t i_out_x = 0; i_out_x < output_x; i_out_x++)
            {
                const int32_t base_idx_x = (i_out_x * stride_x) - pad_x;
                int32_t ker_x_start;
                int32_t ker_x_end;
                depthwise_conv_s4_tap_range(base_idx_x, input_x, kernel_x, dilation_x, &ker_x_start, &ker_x_end);

                /* Input x coordinate of the first in-bounds tap of every kernel row. */
                const int32_t first_x = base_idx_x + dilation_x * ker_x_start;

                if (pair_input_ch)
                {
                    /* ch_mult == 1: output channels c and c + 1 read input channels c and c + 1
                       and share one kernel byte (low nibble, high nibble). */
                    for (int32_t i_ch = 0; i_ch < input_ch; i_ch += 2)
                    {
                        int32_t acc_0 = bias ? bias[i_ch] : 0;
                        int32_t acc_1 = bias ? bias[i_ch + 1] : 0;

                        for (int32_t i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                        {
                            const int32_t idx_y = base_idx_y + dilation_y * i_ker_y;
                            int32_t in_idx = (idx_y * input_x + first_x) * input_ch + i_ch;
                            int32_t ker_idx =
                                (i_ker_y * kernel_x + ker_x_start) * kernel_tap_bytes + (i_ch >> 1);

                            for (int32_t i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
                            {
                                const int8_t packed = kernel[ker_idx];
                                acc_0 += (input[in_idx] + input_offset) * depthwise_conv_s4_low_nibble(packed);
                                acc_1 += (input[in_idx + 1] + input_offset) * depthwise_conv_s4_high_nibble(packed);
                                in_idx += input_tap_step;
                                ker_idx += kernel_tap_bytes;
                            }
                        }

                        output[i_out++] = depthwise_conv_s4_finalize(acc_0,
                                                                     output_mult[i_ch],
                                                                     output_shift[i_ch],
                                                                     output_offset,
                                                                     output_activation_min,
                                                                     output_activation_max);
                        output[i_out++] = depthwise_conv_s4_finalize(acc_1,
                                                                     output_mult[i_ch + 1],
                                                                     output_shift[i_ch + 1],
                                                                     output_offset,
                                                                     output_activation_min,
                                                                     output_activation_max);
                    }
                }
                else if (pair_ch_mult)
                {
                    /* Even ch_mult: consecutive multiplier outputs m and m + 1 of one input channel
                       share one kernel byte. kernel_ch is even here for any input_ch. */
                    for (int32_t i_ch = 0; i_ch < input_ch; i_ch++)
                    {
                        for (int32_t i_mult = 0; i_mult < ch_mult; i_mult += 2)
                        {
                            /* Even, because ch_mult and i_mult are both even. */
                            const int32_t idx_out_ch = i_ch * ch_mult + i_mult;
                            int32_t acc_0 = bias ? bias[idx_out_ch] : 0;
                            int32_t acc_1 = bias ? bias[idx_out_ch + 1] : 0;

                            for (int32_t i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                            {
                                const int32_t idx_y = base_idx_y + dilation_y * i_ker_y;
                                int32_t in_idx = (idx_y * input_x + first_x) * input_ch + i_ch;
                                int32_t ker_idx =
                                    (i_ker_y * kernel_x + ker_x_start) * kernel_tap_bytes + (idx_out_ch >> 1);

                                for (int32_t i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
                                {
                                    const int8_t packed = kernel[ker_idx];
                                    const int32_t in_val = input[in_idx] + input_offset;
                                    acc_0 += in_val * depthwise_conv_s4_low_nibble(packed);
                                    acc_1 += in_val * depthwise_conv_s4_high_nibble(packed);
                                    in_idx += input_tap_step;
                                    ker_idx += kernel_tap_bytes;
                                }
                            }

                            output[i_out++] = depthwise_conv_s4_finalize(acc_0,
                                                                         output_mult[idx_out_ch],
                                                                         output_shift[idx_out_ch],
                                                                         output_offset,
                                                                         output_activation_min,
                                                                         output_activation_max);
                            output[i_out++] = depthwise_conv_s4_finalize(acc_1,
                                                                         output_mult[idx_out_ch + 1],
                                                                         output_shift[idx_out_ch + 1],
                                                                         output_offset,
                                                                         output_activation_min,
                                                                         output_activation_max);
                        }
                    }
                }
                else
                {
                    /* Remaining cases (odd ch_mult > 1, or odd input_ch with odd ch_mult): one output
                       at a time, addressing every weight by its element index tap * kernel_ch + oc.
                       This stays correct when kernel_ch is odd, where the nibble holding a given
                       channel alternates between taps, and when taps are skipped by padding. */
                    for (int32_t i_ch = 0; i_ch < input_ch; i_ch++)
                    {
                        for (int32_t i_mult = 0; i_mult < ch_mult; i_mult++)
                        {
                            const int32_t idx_out_ch = i_ch * ch_mult + i_mult;
                            int32_t acc = bias ? bias[idx_out_ch] : 0;

                            for (int32_t i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                            {
                                const int32_t idx_y = base_idx_y + dilation_y * i_ker_y;
                                int32_t in_idx = (idx_y * input_x + first_x) * input_ch + i_ch;
                                int32_t ker_elem = (i_ker_y * kernel_x + ker_x_start) * kernel_ch + idx_out_ch;

                                for (int32_t i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
                                {
                                    acc += (input[in_idx] + input_offset) * depthwise_conv_s4_weight(kernel, ker_elem);
                                    in_idx += input_tap_step;
                                    ker_elem += kernel_ch;
                                }
                            }

                            output[i_out++] = depthwise_conv_s4_finalize(acc,
                                                                         output_mult[idx_out_ch],
                                                                         output_shift[idx_out_ch],
                                                                         output_offset,
                                                                         output_activation_min,
                                                                         output_activation_max);
                        }
                    }
                }
            }
        }

        /* Advance to the next batch */
        input += (input_x * input_y * input_ch);
    }
}
  426. /*
  427. * Basic s4 depthwise convolution function.
  428. *
  429. * Refer header file for details.
  430. * Optimization using DSP extension is not available for the generic case where channel multiplier is > 1.
  431. *
  432. */
  433. arm_cmsis_nn_status arm_depthwise_conv_s4(const cmsis_nn_context *ctx,
  434. const cmsis_nn_dw_conv_params *dw_conv_params,
  435. const cmsis_nn_per_channel_quant_params *quant_params,
  436. const cmsis_nn_dims *input_dims,
  437. const int8_t *input,
  438. const cmsis_nn_dims *filter_dims,
  439. const int8_t *kernel,
  440. const cmsis_nn_dims *bias_dims,
  441. const int32_t *bias,
  442. const cmsis_nn_dims *output_dims,
  443. int8_t *output)
  444. {
  445. (void)bias_dims;
  446. (void)ctx;
  447. const int32_t dilation_x = dw_conv_params->dilation.w;
  448. const int32_t dilation_y = dw_conv_params->dilation.h;
  449. depthwise_conv_s4_generic(input,
  450. input_dims->n,
  451. input_dims->w,
  452. input_dims->h,
  453. input_dims->c,
  454. kernel,
  455. output_dims->c,
  456. dw_conv_params->ch_mult,
  457. filter_dims->w,
  458. filter_dims->h,
  459. dw_conv_params->padding.w,
  460. dw_conv_params->padding.h,
  461. dw_conv_params->stride.w,
  462. dw_conv_params->stride.h,
  463. bias,
  464. output,
  465. quant_params->shift,
  466. quant_params->multiplier,
  467. output_dims->w,
  468. output_dims->h,
  469. dw_conv_params->output_offset,
  470. dw_conv_params->input_offset,
  471. dw_conv_params->activation.min,
  472. dw_conv_params->activation.max,
  473. dilation_x,
  474. dilation_y);
  475. /* Return to application */
  476. return ARM_CMSIS_NN_SUCCESS;
  477. }
/**
 * @} end of NNConv group
 */