/* * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates * * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the License); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* ---------------------------------------------------------------------- * Project: CMSIS NN Library * Title: arm_nn_mat_mult_kernel_s8_s16.c * Description: Matrix-multiplication function for convolution * * $Date: 29 May 2023 * $Revision: V.2.0.0 * * Target : Arm(R) M-Profile Architecture * -------------------------------------------------------------------- */ #include "arm_nnfunctions.h" #include "arm_nnsupportfunctions.h" /* * Matrix-multiplication function for convolution with per-channel requantization. * * Refer header file for details. * */ int8_t *arm_nn_mat_mult_kernel_s8_s16(const int8_t *input_a, const int16_t *input_b, const uint16_t output_ch, const int32_t *out_shift, const int32_t *out_mult, const int32_t out_offset, const int16_t activation_min, const int16_t activation_max, const int32_t num_col_a, const int32_t aligned_num_col_a, const int32_t *const output_bias, int8_t *out_0) { #if !defined(ARM_MATH_MVEI) /* set up the second output pointers */ int8_t *out_1 = out_0 + output_ch; const int32_t *bias = output_bias; uint16_t row_count = output_ch / 2; const int8_t *ip_a0 = input_a; /* this loop over rows in A */ while (row_count) { /* setup pointers for B */ const int16_t *ip_b0 = input_b; const int16_t *ip_b1 = ip_b0 + aligned_num_col_a; /* align the second pointer for A */ const int8_t *ip_a1 = ip_a0 + num_col_a; int32_t ch_0_out_0 = 0; int32_t ch_0_out_1 = 0; int32_t ch_1_out_0 = 0; int32_t ch_1_out_1 = 0; /* Init accumulator with bias for channel N and N + 1 */ if (bias) { ch_0_out_0 = *bias; ch_0_out_1 = *bias++; ch_1_out_0 = *bias; ch_1_out_1 = *bias++; } #if defined(ARM_MATH_DSP) int32_t col_count = num_col_a / 4; /* accumulate over the vector */ while (col_count) { int32_t a01, a02, a11, a12; int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0); int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1); ip_a0 = read_and_pad_reordered(ip_a0, &a01, &a02); ip_a1 = read_and_pad_reordered(ip_a1, &a11, &a12); ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0); ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1); ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0); ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1); b0 = arm_nn_read_q15x2_ia(&ip_b0); b1 = arm_nn_read_q15x2_ia(&ip_b1); ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0); ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1); ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0); ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1); col_count--; } /* while over col_count */ col_count = num_col_a & 0x3; #else int32_t col_count = num_col_a; #endif while (col_count) { int8_t a0 = *ip_a0++; int16_t b0 = *ip_b0++; int8_t a1 = *ip_a1++; int16_t b1 = *ip_b1++; ch_0_out_0 += a0 * b0; ch_0_out_1 += a0 * b1; ch_1_out_0 += a1 * b0; ch_1_out_1 += a1 * b1; col_count--; } /* while over col_count */ ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift); ch_0_out_0 += out_offset; ch_0_out_0 = MAX(ch_0_out_0, activation_min); ch_0_out_0 = MIN(ch_0_out_0, activation_max); *out_0++ = (int8_t)ch_0_out_0; ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift); ch_0_out_1 += out_offset; ch_0_out_1 = MAX(ch_0_out_1, activation_min); ch_0_out_1 = MIN(ch_0_out_1, activation_max); *out_1++ = (int8_t)ch_0_out_1; out_mult++; out_shift++; ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift); ch_1_out_0 += out_offset; ch_1_out_0 = MAX(ch_1_out_0, activation_min); ch_1_out_0 = MIN(ch_1_out_0, activation_max); *out_0++ = (int8_t)ch_1_out_0; ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift); ch_1_out_1 += out_offset; ch_1_out_1 = MAX(ch_1_out_1, activation_min); ch_1_out_1 = MIN(ch_1_out_1, activation_max); *out_1++ = (int8_t)ch_1_out_1; out_mult++; out_shift++; /* skip row */ ip_a0 += num_col_a; row_count--; } /* compute the last odd numbered row if any */ if (output_ch & 0x1) { /* setup pointers for B */ const int16_t *ip_b0 = input_b; const int16_t *ip_b1 = ip_b0 + aligned_num_col_a; int32_t ch_0_out_0 = 0; int32_t ch_0_out_1 = 0; /* load the bias */ if (bias) { ch_0_out_0 = *bias; ch_0_out_1 = *bias++; } #if defined(ARM_MATH_DSP) int32_t col_count = num_col_a >> 2; while (col_count) { int32_t a01, a02; int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0); int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1); ip_a0 = read_and_pad_reordered(ip_a0, &a01, &a02); ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0); ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1); b0 = arm_nn_read_q15x2_ia(&ip_b0); b1 = arm_nn_read_q15x2_ia(&ip_b1); ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0); ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1); col_count--; } col_count = num_col_a & 0x3; #else int32_t col_count = num_col_a; #endif while (col_count) { int8_t a0 = *ip_a0++; int16_t b0 = *ip_b0++; int16_t b1 = *ip_b1++; ch_0_out_0 += a0 * b0; ch_0_out_1 += a0 * b1; col_count--; } ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift); ch_0_out_0 += out_offset; ch_0_out_0 = MAX(ch_0_out_0, activation_min); ch_0_out_0 = MIN(ch_0_out_0, activation_max); *out_0++ = (int8_t)ch_0_out_0; ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift); ch_0_out_1 += out_offset; ch_0_out_1 = MAX(ch_0_out_1, activation_min); ch_0_out_1 = MIN(ch_0_out_1, activation_max); *out_1++ = (int8_t)ch_0_out_1; out_mult++; out_shift++; } out_0 += output_ch; /* return the new output pointer with offset */ return out_0; #else (void)input_a; (void)input_b; (void)output_ch; (void)out_shift; (void)out_mult; (void)out_offset; (void)activation_min; (void)activation_max; (void)aligned_num_col_a, (void)num_col_a; (void)output_bias; (void)out_0; /* To be completed */ return NULL; #endif }