/*
 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_s8_nt_t_s8
 * Description:  Matrix multiplication support function with the right-hand-side (rhs) matrix transposed
 *
 * $Date:        04 January 2024
 * $Revision:    V.3.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportConvolution
 * @{
 */

/*
 * s8 matrix multiplication with the right-hand-side matrix transposed
 *
 * Refer to the header file for details.
 *
 */
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8(const int8_t *lhs,
                                            const int8_t *rhs,
                                            const int32_t *bias,
                                            int8_t *dst,
                                            const int32_t *dst_multipliers,
                                            const int32_t *dst_shifts,
                                            const int32_t lhs_rows,
                                            const int32_t rhs_rows,
                                            const int32_t rhs_cols,
                                            const int32_t lhs_offset,
                                            const int32_t dst_offset,
                                            const int32_t activation_min,
                                            const int32_t activation_max,
                                            const int32_t row_address_offset,
                                            const int32_t lhs_cols_offset)
{
#if defined(ARM_MATH_MVEI)
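    /* MVE path: process four lhs rows per pass, computing one output column (one rhs row)
       at a time so that a single rhs vector load is reused against all four lhs rows. */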
    int i_items = 0;
    for (; i_items <= (lhs_rows - 4); i_items += 4)
    {
        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            int32_t acc_n1 = 0;
            int32_t acc_n2 = 0;
            int32_t acc_n3 = 0;

            const int8_t *lhs_vec = lhs;
            const int8_t *ip_row_1 = lhs + lhs_cols_offset;
            const int8_t *ip_row_2 = lhs + (2 * lhs_cols_offset);
            const int8_t *ip_row_3 = lhs + (3 * lhs_cols_offset);
            const int8_t *col_base = rhs + i * rhs_cols;
            int32_t sum_tmp = 0;

#if defined(ARM_MATH_AUTOVECTORIZE)
            for (int j = 0; j < rhs_cols; j++)
            {
                int32_t col = col_base[j];
                sum_tmp += col;
                acc_n0 += lhs_vec[j] * col;
                acc_n1 += ip_row_1[j] * col;
                acc_n2 += ip_row_2[j] * col;
                acc_n3 += ip_row_3[j] * col;
            }
#else
            // Note: If operand initialization is moved around, use '&' constraint to
            // specify earlyclobber operands.
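            /* Tail-predicated low-overhead loop (wlstp/letp): rhs_cols need not be a multiple
               of 16. vaddva.s8 accumulates the rhs column sum used later for the lhs_offset
               correction, while vmladava.s8 accumulates one dot product per lhs row. */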
            __ASM volatile(" .p2align 2                           \n"
                           "   wlstp.8       lr, %[cnt], 1f       \n"
                           "   mov           %[sum], 0            \n"
                           "   mov           %[out0], 0           \n"
                           "   mov           %[out1], 0           \n"
                           "   mov           %[out2], 0           \n"
                           "   mov           %[out3], 0           \n"
                           "   vldrb.8       q0, [%[col]], #16    \n"
                           "2:                                    \n"
                           "   vaddva.s8     %[sum], q0           \n"
                           "   vldrb.8       q1, [%[row0]], #16   \n"
                           "   vmladava.s8   %[out0], q0, q1      \n"
                           "   vldrb.8       q2, [%[row1]], #16   \n"
                           "   vmladava.s8   %[out1], q0, q2      \n"
                           "   vldrb.8       q3, [%[row2]], #16   \n"
                           "   vmladava.s8   %[out2], q0, q3      \n"
                           "   vldrb.8       q4, [%[row3]], #16   \n"
                           "   vmladava.s8   %[out3], q0, q4      \n"
                           "   vldrb.8       q0, [%[col]], #16    \n"
                           "   letp          lr, 2b               \n"
                           "1:                                    \n"
                           : [col] "+r"(col_base),
                             [sum] "=Te"(sum_tmp),
                             [row0] "+r"(lhs_vec),
                             [row1] "+r"(ip_row_1),
                             [row2] "+r"(ip_row_2),
                             [row3] "+r"(ip_row_3),
                             [out0] "=Te"(acc_n0),
                             [out1] "=Te"(acc_n1),
                             [out2] "=Te"(acc_n2),
                             [out3] "=Te"(acc_n3)
                           : [cnt] "r"(rhs_cols)
                           : "q0", "q1", "q2", "q3", "q4", "memory", "r14");
#endif
            int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};
            sum_tmp *= lhs_offset;
            if (bias)
            {
                sum_tmp += bias[i];
            }
            res = vaddq_n_s32(res, sum_tmp);

            res = arm_requantize_mve(res, dst_multipliers[i], dst_shifts[i]);
            res = vaddq_n_s32(res, dst_offset);

            res = vmaxq_s32(res, vdupq_n_s32(activation_min));
            res = vminq_s32(res, vdupq_n_s32(activation_max));
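            /* Store the four requantized results as int8: one element in each of the four
               output rows, which are row_address_offset elements apart in dst. */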
            const uint32x4_t scatter_offset = {
                0, (uint32_t)row_address_offset, (uint32_t)row_address_offset * 2, (uint32_t)row_address_offset * 3};
            vstrbq_scatter_offset_s32(dst, scatter_offset, res);
            dst++;
        }
        lhs += 4 * lhs_cols_offset;
        dst += 4 * row_address_offset - rhs_rows;
    }
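    /* The remaining 1-3 lhs rows are processed one at a time. Results are buffered four at a
       time in acc[] so requantization can still use a 32x4 vector operation; any final 1-3
       outputs per row are requantized with scalar code. */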
    for (; i_items < lhs_rows; i_items++)
    {
        int32_t acc[4];
        const int32_t *multipliers = dst_multipliers;
        const int32_t *shifts = dst_shifts;

        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            const int8_t *lhs_vec = lhs;
            const int8_t *col_base = rhs + i * rhs_cols;
            int32_t sum_tmp = 0;

#if defined(ARM_MATH_AUTOVECTORIZE)
            for (int j = 0; j < rhs_cols; j++)
            {
                int32_t col = col_base[j];
                sum_tmp += col;
                acc_n0 += lhs_vec[j] * col;
            }
#else
            __ASM volatile(" .p2align 2                           \n"
                           "   wlstp.8       lr, %[cnt], 1f       \n"
                           "   mov           %[sum], 0            \n"
                           "   mov           %[out0], 0           \n"
                           "   vldrb.8       q0, [%[col]], #16    \n"
                           "2:                                    \n"
                           "   vaddva.s8     %[sum], q0           \n"
                           "   vldrb.8       q1, [%[row0]], #16   \n"
                           "   vmladava.s8   %[out0], q0, q1      \n"
                           "   vldrb.8       q0, [%[col]], #16    \n"
                           "   letp          lr, 2b               \n"
                           "1:                                    \n"
                           : [col] "+r"(col_base), [sum] "=Te"(sum_tmp), [row0] "+r"(lhs_vec), [out0] "=Te"(acc_n0)
                           : [cnt] "r"(rhs_cols)
                           : "q0", "q1", "memory", "r14");
#endif

            sum_tmp *= lhs_offset;
            sum_tmp += acc_n0;
            if (bias)
            {
                sum_tmp += bias[i];
            }
            const int32_t index = i & 0x3;
            acc[index] = sum_tmp;

            if (index == 3)
            {
                int32x4_t res = vldrwq_s32(acc);
                res = arm_requantize_mve_32x4(res, vldrwq_s32(multipliers), vldrwq_s32(shifts));
                multipliers += 4;
                shifts += 4;
                res = vaddq_n_s32(res, dst_offset);
                res = vmaxq_s32(res, vdupq_n_s32(activation_min));
                res = vminq_s32(res, vdupq_n_s32(activation_max));
                vstrbq_s32(dst, res);
                dst += 4;
            }
        }
        lhs += lhs_cols_offset;

        const int32_t tail_rows = rhs_rows & 0x3;
        for (int i = 0; i < tail_rows; i++)
        {
            int32_t acc_n0 = acc[i];
            acc_n0 = arm_nn_requantize(acc_n0, multipliers[i], shifts[i]);
            acc_n0 += dst_offset;
            acc_n0 = MAX(acc_n0, activation_min);
            acc_n0 = MIN(acc_n0, activation_max);
            *dst++ = (int8_t)acc_n0;
        }

        dst += row_address_offset - rhs_rows;
    }

#elif defined(ARM_MATH_DSP)
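    /* DSP path: compute a 2x2 block of outputs per inner iteration (two rhs rows against two
       lhs rows). s8 operands are widened in pairs with SXTB16 / SXTB16_RORn and accumulated
       with SMLAD dual multiply-accumulates. */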
    (void)row_address_offset;
    const int32_t rhs_off0 = rhs_cols - 4;
    const int32_t lhs_off0 = lhs_cols_offset - 4;

    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        int32_t lhs_offset_contribution0 = 0;
        int32_t lhs_offset_contribution1 = 0;
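        /* Fold the constant lhs_offset into the accumulators up front:
           lhs_offset * (sum of each rhs row), plus the bias when present. */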
        for (int32_t x = 0; x < rhs_cols; ++x)
        {
            lhs_offset_contribution0 += rhs[x];
            lhs_offset_contribution1 += rhs[x + rhs_cols];
        }

        lhs_offset_contribution0 *= lhs_offset;
        lhs_offset_contribution1 *= lhs_offset;
        if (bias)
        {
            lhs_offset_contribution0 += bias[rhs_rows_idx];
            lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
        }

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;
            int32_t res10 = lhs_offset_contribution0;
            int32_t res11 = lhs_offset_contribution1;

            int32_t rhs_cols_idx = 0;

            int32_t val0, val1, val2, val3, val4, val5;
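            /* Main loop: 16 rhs columns per iteration, unrolled as four groups of four.
               Each group loads 4 s8 values per operand row and issues SMLADs on the
               even/odd s16 halves produced by SXTB16 / SXTB16_RORn. */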
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                res11 = SMLAD(val0, val4, res11);
            }
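            /* Handle remaining columns: groups of four first, then any final 1-3 columns
               one at a time. */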
            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                res11 = SMLAD(val0, val4, res11);
            }

            for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                lhs_value = lhs_ptr[lhs_cols_offset];
                res10 += lhs_value * rhs_value0;
                res11 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[1] = (int8_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            lhs_rows_idx--;
        }
        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;

            int32_t rhs_cols_idx = 0;

            int32_t val0, val1, val2, val3, val4, val5;
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);
            }

            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);
            }

            // Left-over accumulations
            for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
        }

        rhs += 2 * rhs_cols;
        dst += 2;
    }
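    /* If rhs_rows is odd, the last rhs row is handled here: one output column per lhs row,
       computed with a plain scalar loop. */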
    if (rhs_rows % 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int32_t res00 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows - 1];
            }

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int32_t rhs_value = rhs_ptr[0];
                int32_t lhs_value = lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value;

                ++rhs_ptr;
                ++lhs_ptr;
            }
            lhs_ptr -= rhs_cols;
            lhs_ptr += lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
        }
    }

#else
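    /* Pure C fallback: same 2x2 output blocking as the DSP path, but all
       multiply-accumulates are plain scalar operations. */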
    (void)row_address_offset;

    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        int32_t lhs_offset_contribution0 = 0;
        int32_t lhs_offset_contribution1 = 0;

        for (int32_t x = 0; x < rhs_cols; ++x)
        {
            lhs_offset_contribution0 += rhs[x];
            lhs_offset_contribution1 += rhs[x + rhs_cols];
        }

        lhs_offset_contribution0 *= lhs_offset;
        lhs_offset_contribution1 *= lhs_offset;
        if (bias)
        {
            lhs_offset_contribution0 += bias[rhs_rows_idx];
            lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
        }

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;
            int32_t res10 = lhs_offset_contribution0;
            int32_t res11 = lhs_offset_contribution1;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                lhs_value = lhs_ptr[lhs_cols_offset];
                res10 += lhs_value * rhs_value0;
                res11 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[1] = (int8_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            lhs_rows_idx--;
        }

        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
        }

        rhs += 2 * rhs_cols;
        dst += 2;
    }

    if (rhs_rows % 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int32_t res00 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows - 1];
            }

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int32_t rhs_value = rhs_ptr[0];
                int32_t lhs_value = lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value;

                ++rhs_ptr;
                ++lhs_ptr;
            }
            lhs_ptr -= rhs_cols;
            lhs_ptr += lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
        }
    }
#endif
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */