| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439 |
- /*
- * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
- *
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the License); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- /* ----------------------------------------------------------------------
- * Project: CMSIS NN Library
- * Title: arm_nn_mat_mult_kernel_s4_s16.c
- * Description: Matrix-multiplication function for convolution
- *
- * $Date: 01 November 2023
- * $Revision: V.1.0.0
- *
- * Target : Arm(R) M-Profile Architecture
- * -------------------------------------------------------------------- */
- #include "arm_nnsupportfunctions.h"
- /*
- * Matrix-multiplication function for convolution with per-channel requantization and 4bit weights.
- *
- * Refer to the header file for details.
- *
- * Summary derived from the implementation below:
- *   packed_input_a  Weights, two signed 4-bit values per byte. The low nibble is
- *                   used before the high nibble. Rows are packed back to back at
- *                   nibble granularity, so when num_col_a is odd the last value
- *                   of one row shares a byte with the first value of the next.
- *   input_b         Read as two consecutive vectors of num_col_a int16 values
- *                   (input_b and input_b + num_col_a).
- *   output_ch       Number of rows of A, i.e. number of results written per
- *                   output pointer.
- *   out_shift,
- *   out_mult        Indexed per output channel and passed to arm_nn_requantize.
- *   out_offset      Added to every requantized value.
- *   activation_min,
- *   activation_max  Clamp range applied before the result is stored as int8_t.
- *   num_col_a       Number of 4-bit columns per row of A.
- *   output_bias     Per-channel accumulator start value; may be NULL, in which
- *                   case accumulators start at 0.
- *   out_0           First output pointer; the second vector of B is written
- *                   through out_0 + output_ch.
- *
- * Returns out_0 advanced by 2 * output_ch from the value passed in.
- */
- int8_t *arm_nn_mat_mult_kernel_s4_s16(const int8_t *packed_input_a,
- const int16_t *input_b,
- const uint16_t output_ch,
- const int32_t *out_shift,
- const int32_t *out_mult,
- const int32_t out_offset,
- const int32_t activation_min,
- const int32_t activation_max,
- const int32_t num_col_a,
- const int32_t *const output_bias,
- int8_t *out_0)
- {
- /* Set up the second output pointer: results for the second vector of B
- * (ip_b1) are stored output_ch elements after those for the first. */
- int8_t *out_1 = out_0 + output_ch;
- const int32_t *bias = output_bias;
- uint16_t row_count = output_ch / 4; /* each iteration below produces 4 output channels */
- const int8_t *packed_ip_a0 = packed_input_a;
- /* Loop over rows in A, four rows (N .. N + 3) per iteration.
- * Pass 1 handles rows N and N + 2, pass 2 handles rows N + 1 and N + 3.
- * Pairing rows two apart means both rows of a pair start at the same
- * nibble position inside a byte, even when num_col_a is odd.
- */
- while (row_count)
- {
- int8_t spillover0 = 0; /* first value of row N + 1 when it shares a byte with the end of row N */
- int8_t spillover1 = 0; /* first value of row N + 3 when it shares a byte with the end of row N + 2 */
- /* setup pointers for B */
- const int16_t *ip_b0 = input_b;
- const int16_t *ip_b1 = ip_b0 + num_col_a;
- /* Align the second pointer for A.
- * This will skip a row so that we can ensure that spilled rows
- * don't offset the symmetry.
- * (num_col_a bytes hold 2 * num_col_a nibbles, i.e. two rows.)
- */
- const int8_t *packed_ip_a1 = packed_ip_a0 + num_col_a;
- int32_t ch_0_out_0 = 0; /* row of packed_ip_a0 x ip_b0 */
- int32_t ch_0_out_1 = 0; /* row of packed_ip_a0 x ip_b1 */
- int32_t ch_1_out_0 = 0; /* row of packed_ip_a1 x ip_b0 */
- int32_t ch_1_out_1 = 0; /* row of packed_ip_a1 x ip_b1 */
- /* Init accumulators with bias for channel N and N + 2.
- * The trailing decrement leaves bias pointing at channel N + 1 for pass 2.
- */
- if (bias)
- {
- ch_0_out_0 = *bias;
- ch_0_out_1 = *bias;
- bias += 2;
- ch_1_out_0 = *bias;
- ch_1_out_1 = *bias--;
- }
- #if defined(ARM_MATH_DSP)
- int32_t col_count = num_col_a / 4;
- /* Accumulate over the vector, four columns (two packed bytes) per iteration.
- * NOTE(review): read_and_pad_s4_ordered presumably expands two packed bytes
- * into two words of paired 16-bit values suitable for SMLAD, and
- * arm_nn_read_q15x2_ia presumably reads two int16 values and advances the
- * pointer - both are defined elsewhere, confirm against their definitions.
- */
- while (col_count)
- {
- int32_t a01, a02, a11, a12;
- int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
- int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);
- read_and_pad_s4_ordered(packed_ip_a0, &a01, &a02);
- read_and_pad_s4_ordered(packed_ip_a1, &a11, &a12);
- packed_ip_a0 += 2;
- packed_ip_a1 += 2;
- ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
- ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
- ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
- ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);
- b0 = arm_nn_read_q15x2_ia(&ip_b0);
- b1 = arm_nn_read_q15x2_ia(&ip_b1);
- ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
- ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
- ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
- ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);
- col_count--;
- } /* while over col_count */
- col_count = (num_col_a & 0x3) >> 1; /* 0 or 1 whole byte left for the scalar loop */
- #else
- int32_t col_count = num_col_a >> 1; /* number of whole bytes (column pairs) in a row */
- #endif
- while (col_count) /* scalar path: one packed byte (two columns) per iteration */
- {
- int8_t lower_a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4; /* sign-extend the low nibble */
- int8_t higher_a0 = packed_ip_a0[0] >> 4; /* high nibble; sign comes from >> on a signed value */
- int16_t b0 = *ip_b0++;
- int8_t lower_a1 = (int8_t)(packed_ip_a1[0] << 4) >> 4;
- int8_t higher_a1 = packed_ip_a1[0] >> 4;
- int16_t b1 = *ip_b1++;
- packed_ip_a0++;
- packed_ip_a1++;
- ch_0_out_0 += lower_a0 * b0;
- ch_0_out_1 += lower_a0 * b1;
- ch_1_out_0 += lower_a1 * b0;
- ch_1_out_1 += lower_a1 * b1;
- b0 = *ip_b0++;
- b1 = *ip_b1++;
- ch_0_out_0 += higher_a0 * b0;
- ch_0_out_1 += higher_a0 * b1;
- ch_1_out_0 += higher_a1 * b0;
- ch_1_out_1 += higher_a1 * b1;
- col_count--;
- } /* while over col_count */
- /* Left over column (odd num_col_a): the low nibble is the last value of
- * the current row. The high nibble of the same byte is the first value of
- * the following row; it is kept in spillover0/1 and consumed in pass 2.
- */
- if (num_col_a % 2)
- {
- int8_t lower_a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;
- spillover0 = packed_ip_a0[0] >> 4;
- int16_t b0 = *ip_b0++;
- int8_t lower_a1 = (int8_t)(packed_ip_a1[0] << 4) >> 4;
- spillover1 = packed_ip_a1[0] >> 4;
- int16_t b1 = *ip_b1++;
- packed_ip_a0++;
- packed_ip_a1++;
- ch_0_out_0 += lower_a0 * b0;
- ch_0_out_1 += lower_a0 * b1;
- ch_1_out_0 += lower_a1 * b0;
- ch_1_out_1 += lower_a1 * b1;
- }
- ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift); /* channel N, first vector of B */
- ch_0_out_0 += out_offset;
- ch_0_out_0 = MAX(ch_0_out_0, activation_min);
- ch_0_out_0 = MIN(ch_0_out_0, activation_max);
- *out_0 = (int8_t)ch_0_out_0;
- out_0 += 2; /* move on to channel N + 2 */
- ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift); /* channel N, second vector of B */
- ch_0_out_1 += out_offset;
- ch_0_out_1 = MAX(ch_0_out_1, activation_min);
- ch_0_out_1 = MIN(ch_0_out_1, activation_max);
- *out_1 = (int8_t)ch_0_out_1;
- out_1 += 2;
- out_mult += 2; /* quantization parameters of channel N + 2 */
- out_shift += 2;
- ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
- ch_1_out_0 += out_offset;
- ch_1_out_0 = MAX(ch_1_out_0, activation_min);
- ch_1_out_0 = MIN(ch_1_out_0, activation_max);
- *out_0-- = (int8_t)ch_1_out_0; /* store channel N + 2, then step back to N + 1 */
- ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
- ch_1_out_1 += out_offset;
- ch_1_out_1 = MAX(ch_1_out_1, activation_min);
- ch_1_out_1 = MIN(ch_1_out_1, activation_max);
- *out_1-- = (int8_t)ch_1_out_1;
- out_mult--; /* back to channel N + 1 for pass 2 */
- out_shift--;
- /* Pass 2 (rows N + 1 and N + 3): rewind the pointers for B */
- ip_b0 = input_b;
- ip_b1 = ip_b0 + num_col_a;
- /* Align the second pointer for A.
- * This will skip a row so that we can ensure that spilled rows
- * don't offset the symmetry.
- * packed_ip_a0 was advanced by pass 1 and now points into row N + 1.
- */
- packed_ip_a1 = packed_ip_a0 + num_col_a;
- ch_0_out_0 = 0;
- ch_0_out_1 = 0;
- ch_1_out_0 = 0;
- ch_1_out_1 = 0;
- /* Init accumulators with bias for channel N + 1 and N + 3.
- * The trailing increment leaves bias pointing at channel N + 4.
- */
- if (bias)
- {
- ch_0_out_0 = *bias;
- ch_0_out_1 = *bias;
- bias += 2;
- ch_1_out_0 = *bias;
- ch_1_out_1 = *bias++;
- }
- if (num_col_a % 2) /* first column of these rows is the nibble saved during pass 1 */
- {
- int16_t b0 = *ip_b0++;
- int16_t b1 = *ip_b1++;
- ch_0_out_0 += spillover0 * b0;
- ch_0_out_1 += spillover0 * b1;
- ch_1_out_0 += spillover1 * b0;
- ch_1_out_1 += spillover1 * b1;
- }
- #if defined(ARM_MATH_DSP)
- col_count = num_col_a / 4;
- /* Accumulate over the vector, four columns per iteration (same scheme as pass 1) */
- while (col_count)
- {
- int32_t a01, a02, a11, a12;
- int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
- int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);
- read_and_pad_s4_ordered(packed_ip_a0, &a01, &a02);
- read_and_pad_s4_ordered(packed_ip_a1, &a11, &a12);
- packed_ip_a0 += 2;
- packed_ip_a1 += 2;
- ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
- ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
- ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
- ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);
- b0 = arm_nn_read_q15x2_ia(&ip_b0);
- b1 = arm_nn_read_q15x2_ia(&ip_b1);
- ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
- ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
- ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
- ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);
- col_count--;
- } /* while over col_count */
- col_count = (num_col_a & 0x3) >> 1;
- #else
- col_count = num_col_a >> 1;
- #endif
- while (col_count) /* scalar path: one packed byte (two columns) per iteration */
- {
- int8_t lower_a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;
- int8_t higher_a0 = packed_ip_a0[0] >> 4;
- int16_t b0 = *ip_b0++;
- int8_t lower_a1 = (int8_t)(packed_ip_a1[0] << 4) >> 4;
- int8_t higher_a1 = packed_ip_a1[0] >> 4;
- int16_t b1 = *ip_b1++;
- packed_ip_a0++;
- packed_ip_a1++;
- ch_0_out_0 += lower_a0 * b0;
- ch_0_out_1 += lower_a0 * b1;
- ch_1_out_0 += lower_a1 * b0;
- ch_1_out_1 += lower_a1 * b1;
- b0 = *ip_b0++;
- b1 = *ip_b1++;
- ch_0_out_0 += higher_a0 * b0;
- ch_0_out_1 += higher_a0 * b1;
- ch_1_out_0 += higher_a1 * b0;
- ch_1_out_1 += higher_a1 * b1;
- col_count--;
- } /* while over col_count */
- ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift); /* channel N + 1, first vector of B */
- ch_0_out_0 += out_offset;
- ch_0_out_0 = MAX(ch_0_out_0, activation_min);
- ch_0_out_0 = MIN(ch_0_out_0, activation_max);
- *out_0 = (int8_t)ch_0_out_0;
- out_0 += 2; /* move on to channel N + 3 */
- ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift); /* channel N + 1, second vector of B */
- ch_0_out_1 += out_offset;
- ch_0_out_1 = MAX(ch_0_out_1, activation_min);
- ch_0_out_1 = MIN(ch_0_out_1, activation_max);
- *out_1 = (int8_t)ch_0_out_1;
- out_1 += 2;
- out_mult += 2; /* quantization parameters of channel N + 3 */
- out_shift += 2;
- ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
- ch_1_out_0 += out_offset;
- ch_1_out_0 = MAX(ch_1_out_0, activation_min);
- ch_1_out_0 = MIN(ch_1_out_0, activation_max);
- *out_0++ = (int8_t)ch_1_out_0; /* store channel N + 3, then advance to N + 4 */
- ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
- ch_1_out_1 += out_offset;
- ch_1_out_1 = MAX(ch_1_out_1, activation_min);
- ch_1_out_1 = MIN(ch_1_out_1, activation_max);
- *out_1++ = (int8_t)ch_1_out_1;
- out_mult++; /* channel N + 4 for the next iteration */
- out_shift++;
- /* Skip 2 rows: packed_ip_a0 is at the start of row N + 2 here, and rows
- * N + 2 and N + 3 were already consumed through packed_ip_a1. */
- packed_ip_a0 += num_col_a;
- row_count--;
- }
- /* Compute the remaining 0 - 3 rows, if any, one row per iteration */
- int16_t left_over_rows = 0;
- while (left_over_rows < output_ch % 4)
- {
- /* setup pointers for B */
- const int16_t *ip_b0 = input_b;
- const int16_t *ip_b1 = ip_b0 + num_col_a;
- int32_t ch_0_out_0 = 0;
- int32_t ch_0_out_1 = 0;
- /* load the bias */
- if (bias)
- {
- ch_0_out_0 = *bias;
- ch_0_out_1 = *bias++;
- }
- if (left_over_rows == 1 && num_col_a % 2) /* second leftover row starts in the high nibble of the byte the first one ended in */
- {
- int16_t b0 = *ip_b0++;
- int16_t b1 = *ip_b1++;
- int8_t spilled_column = packed_ip_a0[0] >> 4;
- ++packed_ip_a0;
- ch_0_out_0 += spilled_column * b0;
- ch_0_out_1 += spilled_column * b1;
- }
- #if defined(ARM_MATH_DSP)
- int32_t col_count = num_col_a / 4;
- while (col_count) /* four columns per iteration, single row of A */
- {
- int32_t a01, a02;
- int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
- int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);
- read_and_pad_s4_ordered(packed_ip_a0, &a01, &a02);
- packed_ip_a0 += 2;
- ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
- ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
- b0 = arm_nn_read_q15x2_ia(&ip_b0);
- b1 = arm_nn_read_q15x2_ia(&ip_b1);
- ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
- ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
- col_count--;
- }
- col_count = (num_col_a & 0x3) >> 1;
- #else
- int32_t col_count = num_col_a >> 1;
- #endif
- while (col_count) /* scalar path: one packed byte (two columns) per iteration */
- {
- int8_t a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;
- int8_t a1 = packed_ip_a0[0] >> 4;
- int16_t b0 = *ip_b0++;
- int16_t b1 = *ip_b1++;
- ++packed_ip_a0;
- ch_0_out_0 += a0 * b0;
- ch_0_out_1 += a0 * b1;
- b0 = *ip_b0++;
- b1 = *ip_b1++;
- ch_0_out_0 += a1 * b0;
- ch_0_out_1 += a1 * b1;
- col_count--;
- }
- if (num_col_a % 2 && left_over_rows != 1) /* last column is a low nibble; pointer is not advanced so the next row can read the high nibble */
- {
- int8_t a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;
- int16_t b0 = *ip_b0++;
- int16_t b1 = *ip_b1++;
- ch_0_out_0 += a0 * b0;
- ch_0_out_1 += a0 * b1;
- }
- ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
- ch_0_out_0 += out_offset;
- ch_0_out_0 = MAX(ch_0_out_0, activation_min);
- ch_0_out_0 = MIN(ch_0_out_0, activation_max);
- *out_0++ = (int8_t)ch_0_out_0;
- ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
- ch_0_out_1 += out_offset;
- ch_0_out_1 = MAX(ch_0_out_1, activation_min);
- ch_0_out_1 = MIN(ch_0_out_1, activation_max);
- *out_1++ = (int8_t)ch_0_out_1;
- out_mult++;
- out_shift++;
- ++left_over_rows;
- }
- out_0 += output_ch; /* out_0 has advanced by output_ch so far; skip the block written through out_1 */
- /* return the new output pointer with offset */
- return out_0;
- }
|