Kaynağa Gözat

Add non-zero filter offset support for FC (#110)

Fully connected can handle a non-zero filter offset. Added a unit test
as well.

Co-authored-by: Adrian Lundell <adrian.lundell@arm.com>
Måns Nilsson 2 yıl önce
ebeveyn
işleme
72e1ebf623

+ 6 - 3
Include/arm_nnsupportfunctions.h

@@ -21,8 +21,8 @@
  * Title:        arm_nnsupportfunctions.h
  * Description:  Public header file of support functions for CMSIS NN Library
  *
- * $Date:        31 January 2024
- * $Revision:    V.18.1.0
+ * $Date:        14 February 2024
+ * $Revision:    V.19.0.0
  *
  * Target :  Arm(R) M-Profile Architecture
  * -------------------------------------------------------------------- */
@@ -529,6 +529,8 @@ arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s4(const int8_t *lhs,
  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
  * @param[in]      address_offset  Memory position offset for dst. First output is stored at 'dst', the
  *                                 second at 'dst + address_offset' and so on. Default value is typically 1.
+ * @param[in]      rhs_offset      Offset to be added to the input values of the right-hand side vector.
+ *                                 Range: -127 to 128
  *
  * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
  *
@@ -546,7 +548,8 @@ arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s8(const int8_t *lhs,
                                              const int32_t rhs_rows,
                                              const int32_t activation_min,
                                              const int32_t activation_max,
-                                             const int32_t address_offset);
+                                             const int32_t address_offset,
+                                             const int32_t rhs_offset);
 
 /**
  * @brief s16 Vector by Matrix (transposed) multiplication

+ 7 - 6
Source/FullyConnectedFunctions/arm_fully_connected_s8.c

@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+ * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -21,8 +21,8 @@
  * Title:        arm_fully_connected_s8
  * Description:  Fully connected function compatible with TF Lite.
  *
- * $Date:        23 October 2023
- * $Revision:    V.5.2.0
+ * $Date:        6 February 2024
+ * $Revision:    V.5.3.0
  *
  * Target :  Arm(R) M-Profile Architecture
  *
@@ -60,7 +60,6 @@ arm_cmsis_nn_status arm_fully_connected_s8(const cmsis_nn_context *ctx,
                                            int8_t *output)
 {
     (void)bias_dims;
-    (void)fc_params->filter_offset;
 
     int32_t batch_cnt = input_dims->n;
 
@@ -71,10 +70,11 @@ arm_cmsis_nn_status arm_fully_connected_s8(const cmsis_nn_context *ctx,
     }
 #endif
 
-    const int32_t *kernel_sum = (const int32_t *) ctx->buf;
+    const int32_t *kernel_sum = (const int32_t *)ctx->buf;
 
     while (batch_cnt)
     {
+
         arm_nn_vec_mat_mult_t_s8(input,
                                  kernel,
                                  kernel_sum,
@@ -88,7 +88,8 @@ arm_cmsis_nn_status arm_fully_connected_s8(const cmsis_nn_context *ctx,
                                  output_dims->c, /* row_dim or output_depth */
                                  fc_params->activation.min,
                                  fc_params->activation.max,
-                                 1L);
+                                 1L,
+                                 fc_params->filter_offset);
 
         input += filter_dims->n;
         output += output_dims->c;

+ 614 - 260
Source/NNSupportFunctions/arm_nn_vec_mat_mult_t_s8.c

@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+ * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -21,8 +21,8 @@
  * Title:        arm_nn_vec_mat_mult_t_s8
  * Description:  s8 vector by matrix (transposed) multiplication
  *
- * $Date:        5 May 2023
- * $Revision:    V.5.4.1
+ * $Date:        14 Feb 2024
+ * $Revision:    V.6.0.0
  *
  * Target :  Arm(R) M-Profile Architecture
  *
@@ -68,339 +68,693 @@ arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s8(const int8_t *lhs,
                                              const int32_t rhs_rows,
                                              const int32_t activation_min,
                                              const int32_t activation_max,
-                                             const int32_t address_offset)
+                                             const int32_t address_offset,
+                                             const int32_t rhs_offset)
 {
+    if (rhs_offset)
+    {
 #if defined(ARM_MATH_MVEI)
-    const int32_t row_loop_cnt = rhs_rows / 4;
-    const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};
+        const int32_t row_loop_cnt = rhs_rows / 4;
+        const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};
 
-    for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
-    {
-        int32_t acc_0 = 0;
-        int32_t acc_1 = 0;
-        int32_t acc_2 = 0;
-        int32_t acc_3 = 0;
+        for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
+        {
+            int32_t acc_0 = 0;
+            int32_t acc_1 = 0;
+            int32_t acc_2 = 0;
+            int32_t acc_3 = 0;
 
-        const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
+            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
 
-        const int8_t *lhs_vec = lhs;
-        const int8_t *rhs_0 = rhs;
-        const int8_t *rhs_1 = rhs + rhs_cols;
-        const int8_t *rhs_2 = rhs + 2 * rhs_cols;
-        const int8_t *rhs_3 = rhs + 3 * rhs_cols;
+            const int8_t *lhs_vec = lhs;
+            const int8_t *rhs_0_ptr = rhs;
+            const int8_t *rhs_1_ptr = rhs + rhs_cols;
+            const int8_t *rhs_2_ptr = rhs + 2 * rhs_cols;
+            const int8_t *rhs_3_ptr = rhs + 3 * rhs_cols;
 
-        if (bias)
-        {
-            acc_0 = *bias++;
-            acc_1 = *bias++;
-            acc_2 = *bias++;
-            acc_3 = *bias++;
-        }
+            int32_t lhs_sum = 0;
 
-        uint32_t col_cnt = (uint32_t)rhs_cols;
+            if (bias)
+            {
+                acc_0 = *bias++;
+                acc_1 = *bias++;
+                acc_2 = *bias++;
+                acc_3 = *bias++;
+            }
 
-        for (int i = 0; i < col_loop_cnt; i++)
-        {
-            mve_pred16_t p = vctp8q(col_cnt);
-            col_cnt -= 16;
+            uint32_t col_cnt = (uint32_t)rhs_cols;
 
-            const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
+            for (int32_t i = 0; i < col_loop_cnt; i++)
+            {
+                mve_pred16_t p = vctp8q(col_cnt);
+                col_cnt -= 16;
 
-            const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p);
-            acc_0 = vmladavaq_s8(acc_0, ker_0, input);
+                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
+                lhs_sum = vaddvaq_s8(lhs_sum, input);
 
-            const int8x16_t ker_1 = vldrbq_z_s8(rhs_1, p);
-            acc_1 = vmladavaq_s8(acc_1, ker_1, input);
+                const int8x16_t ker_0 = vldrbq_z_s8(rhs_0_ptr, p);
+                acc_0 = vmladavaq_s8(acc_0, ker_0, input);
 
-            const int8x16_t ker_2 = vldrbq_z_s8(rhs_2, p);
-            acc_2 = vmladavaq_s8(acc_2, ker_2, input);
+                const int8x16_t ker_1 = vldrbq_z_s8(rhs_1_ptr, p);
+                acc_1 = vmladavaq_s8(acc_1, ker_1, input);
 
-            const int8x16_t ker_3 = vldrbq_z_s8(rhs_3, p);
-            acc_3 = vmladavaq_s8(acc_3, ker_3, input);
+                const int8x16_t ker_2 = vldrbq_z_s8(rhs_2_ptr, p);
+                acc_2 = vmladavaq_s8(acc_2, ker_2, input);
 
-            lhs_vec += 16;
-            rhs_0 += 16;
-            rhs_1 += 16;
-            rhs_2 += 16;
-            rhs_3 += 16;
-        }
-        rhs += 4 * rhs_cols;
+                const int8x16_t ker_3 = vldrbq_z_s8(rhs_3_ptr, p);
+                acc_3 = vmladavaq_s8(acc_3, ker_3, input);
 
-        int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};
+                lhs_vec += 16;
+                rhs_0_ptr += 16;
+                rhs_1_ptr += 16;
+                rhs_2_ptr += 16;
+                rhs_3_ptr += 16;
+            }
+            rhs += 4 * rhs_cols;
 
-        const int32x4_t rhs_sum = {kernel_sum[0], kernel_sum[1], kernel_sum[2], kernel_sum[3]};
-        acc += vdupq_n_s32(lhs_offset) * rhs_sum;
-        kernel_sum += 4;
+            int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};
 
-        acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
-        acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
-        acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
-        acc = vminq_s32(acc, vdupq_n_s32(activation_max));
+            const int32x4_t rhs_sum = {kernel_sum[0], kernel_sum[1], kernel_sum[2], kernel_sum[3]};
+            acc += vdupq_n_s32(lhs_offset) * rhs_sum;
+            kernel_sum += 4;
 
-        vstrbq_scatter_offset_s32(dst, address_offset_array, acc);
+            acc += vdupq_n_s32(rhs_offset) * vdupq_n_s32(lhs_sum);
+            acc += vdupq_n_s32(rhs_offset * lhs_offset) * vdupq_n_s32(rhs_cols);
 
-        dst += 4 * address_offset;
-    }
+            acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
+            acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
+            acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
+            acc = vminq_s32(acc, vdupq_n_s32(activation_max));
 
-    const int loop_cnt = rhs_rows % 4;
-    for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
-    {
-        int32_t acc_0 = 0;
-        const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
-        const int8_t *lhs_vec = lhs;
-        const int8_t *rhs_0 = rhs;
-        uint32_t col_cnt = (uint32_t)rhs_cols;
+            vstrbq_scatter_offset_s32(dst, address_offset_array, acc);
+
+            dst += 4 * address_offset;
+        }
 
-        for (int i = 0; i < col_loop_cnt; i++)
+        const int loop_cnt = rhs_rows % 4;
+        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
         {
-            mve_pred16_t p = vctp8q(col_cnt);
-            col_cnt -= 16;
-            const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
+            int32_t acc_0 = 0;
+            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
+            const int8_t *lhs_vec = lhs;
+            const int8_t *rhs_ptr = rhs;
+            int32_t lhs_sum = 0;
+            uint32_t col_cnt = (uint32_t)rhs_cols;
+
+            for (int32_t i = 0; i < col_loop_cnt; i++)
+            {
+                mve_pred16_t p = vctp8q(col_cnt);
+                col_cnt -= 16;
+                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
+                lhs_sum = vaddvaq_s8(lhs_sum, input);
+
+                const int8x16_t ker_0 = vldrbq_z_s8(rhs_ptr, p);
+                acc_0 = vmladavaq_s8(acc_0, ker_0, input);
+
+                lhs_vec += 16;
+                rhs_ptr += 16;
+            }
+            rhs += rhs_cols;
+
+            if (bias)
+            {
+                acc_0 += *bias;
+                bias++;
+            }
+            const int32_t rhs_sum = kernel_sum[i_row_loop_cnt];
+            acc_0 += rhs_sum * lhs_offset;
+            acc_0 += lhs_sum * rhs_offset;
+            acc_0 += rhs_cols * lhs_offset * rhs_offset;
+
+            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
+            acc_0 += dst_offset;
+
+            // Clamp the result
+            acc_0 = MAX(acc_0, activation_min);
+            *dst = MIN(acc_0, activation_max);
+            dst += address_offset;
+        }
 
-            const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p);
-            acc_0 = vmladavaq_s8(acc_0, ker_0, input);
+#elif defined(ARM_MATH_DSP)
+        (void)kernel_sum;
+
+        const int32_t row_loop_cnt = rhs_rows / 2;
+        const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
+        const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
 
-            lhs_vec += 16;
-            rhs_0 += 16;
+        const int16_t rhs_offset_s16 = (int16_t)rhs_offset;
+        const uint32_t rhs_offset_s16x2 = PKHBT(rhs_offset_s16, rhs_offset_s16, 16);
+
+        for (int32_t i = 0; i < row_loop_cnt; i++)
+        {
+            int32_t acc_0 = 0;
+            int32_t acc_1 = 0;
+            if (bias)
+            {
+                acc_0 = *bias++;
+                acc_1 = *bias++;
+            }
+
+            const int32_t col_loop_cnt = rhs_cols / 4;
+
+            const int8_t *lhs_vec = lhs;
+            const int8_t *rhs_0_ptr = rhs;
+            const int8_t *rhs_1_ptr = rhs + rhs_cols;
+            rhs += 2 * rhs_cols;
+
+            for (int32_t j = col_loop_cnt; j != 0; j--)
+            {
+                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
+                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
+
+                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
+
+                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0_ptr);
+                int32_t ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
+                ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);
+
+                acc_0 = SMLAD(ker_1, vec_1, acc_0);
+                acc_0 = SMLAD(ker_0, vec_0, acc_0);
+
+                ker_0 = arm_nn_read_s8x4_ia(&rhs_1_ptr);
+                ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
+                ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);
+
+                acc_1 = SMLAD(ker_1, vec_1, acc_1);
+                acc_1 = SMLAD(ker_0, vec_0, acc_1);
+            }
+
+            for (int32_t k = col_loop_cnt * 4; k < rhs_cols; k++)
+            {
+                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
+                lhs_vec++;
+                acc_0 += lhs_temp * (*rhs_0_ptr + rhs_offset);
+                rhs_0_ptr++;
+                acc_1 += lhs_temp * (*rhs_1_ptr + rhs_offset);
+                rhs_1_ptr++;
+            }
+
+            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
+            acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);
+
+            // Add offset
+            acc_0 += dst_offset;
+            acc_1 += dst_offset;
+            // Clamp the result
+            acc_0 = MAX(acc_0, activation_min);
+            acc_0 = MIN(acc_0, activation_max);
+            acc_1 = MAX(acc_1, activation_min);
+            acc_1 = MIN(acc_1, activation_max);
+            *dst = (int8_t)acc_0;
+            *(dst + address_offset) = (int8_t)acc_1;
+            dst += 2 * address_offset;
         }
-        rhs += rhs_cols;
 
-        if (bias)
+        if (rhs_rows & 0x1)
         {
-            acc_0 += *bias;
-            bias++;
+            int32_t acc_0 = 0;
+            if (bias)
+            {
+                acc_0 = *bias++;
+            }
+            const int32_t col_loop_cnt = rhs_cols / 4;
+
+            const int8_t *lhs_vec = lhs;
+            const int8_t *rhs_ptr = rhs;
+
+            for (int32_t i = col_loop_cnt; i != 0; i--)
+            {
+                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
+                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
+                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
+
+                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_ptr);
+                int32_t ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
+                ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);
+
+                acc_0 = SMLAD(ker_1, vec_1, acc_0);
+                acc_0 = SMLAD(ker_0, vec_0, acc_0);
+            }
+
+            for (int32_t j = col_loop_cnt * 4; j < rhs_cols; j++)
+            {
+                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
+                lhs_vec++;
+                acc_0 += lhs_temp * (*rhs_ptr + rhs_offset);
+                rhs_ptr++;
+            }
+
+            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
+
+            // Add offset
+            acc_0 += dst_offset;
+            // Clamp the result
+            acc_0 = MAX(acc_0, activation_min);
+            acc_0 = MIN(acc_0, activation_max);
+            *dst = (int8_t)acc_0;
+            dst += address_offset;
         }
-        const int32_t rhs_sum = kernel_sum[i_row_loop_cnt];
-        const int32_t offsets = rhs_sum * lhs_offset;
-        acc_0 += offsets;
-        acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
-        acc_0 += dst_offset;
-
-        // Clamp the result
-        acc_0 = MAX(acc_0, activation_min);
-        *dst = MIN(acc_0, activation_max);
-        dst += address_offset;
-    }
 
-#elif defined(ARM_MATH_DSP)
-    (void)kernel_sum;
+#else
+        (void)kernel_sum;
 
-    const int32_t row_loop_cnt = rhs_rows / 2;
-    const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
-    const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
+        const int32_t row_loop_cnt = rhs_rows / 3;
 
-    for (int32_t i = 0; i < row_loop_cnt; i++)
-    {
-        int32_t acc_0 = 0;
-        int32_t acc_1 = 0;
-        if (bias)
+        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
         {
-            acc_0 = *bias++;
-            acc_1 = *bias++;
+            const int8_t *lhs_ptr = lhs;
+            const int8_t *rhs_ptr_0 = &rhs[0];
+            const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
+            const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];
+
+            int32_t res00 = 0;
+            int32_t res01 = 0;
+            int32_t res02 = 0;
+            if (bias)
+            {
+                res00 = *bias++;
+                res01 = *bias++;
+                res02 = *bias++;
+            }
+            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
+            {
+                const int32_t rhs_value0 = (int8_t)*rhs_ptr_0 + rhs_offset;
+                const int32_t rhs_value1 = (int8_t)*rhs_ptr_1 + rhs_offset;
+                const int32_t rhs_value2 = (int8_t)*rhs_ptr_2 + rhs_offset;
+                const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;
+
+                res00 += lhs_value * rhs_value0;
+                res01 += lhs_value * rhs_value1;
+                res02 += lhs_value * rhs_value2;
+
+                ++rhs_ptr_0;
+                ++rhs_ptr_1;
+                ++rhs_ptr_2;
+                ++lhs_ptr;
+            }
+
+            // Quantize down
+            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
+            res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
+            res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);
+
+            // Add offset
+            res00 += dst_offset;
+            res01 += dst_offset;
+            res02 += dst_offset;
+
+            // Clamp the result
+            res00 = MAX(res00, activation_min);
+            res00 = MIN(res00, activation_max);
+            res01 = MAX(res01, activation_min);
+            res01 = MIN(res01, activation_max);
+            res02 = MAX(res02, activation_min);
+            res02 = MIN(res02, activation_max);
+
+            *dst = (int8_t)res00;
+            *(dst + address_offset) = (int8_t)res01;
+            *(dst + 2 * address_offset) = (int8_t)res02;
+            dst += 3 * address_offset;
+
+            rhs += 3 * rhs_cols;
         }
 
-        const int32_t col_loop_cnt = rhs_cols / 4;
+        const int loop_cnt = rhs_rows % 3;
 
-        const int8_t *lhs_vec = lhs;
-        const int8_t *rhs_0 = rhs;
-        const int8_t *rhs_1 = rhs + rhs_cols;
-        rhs += 2 * rhs_cols;
-
-        for (int j = col_loop_cnt; j != 0; j--)
+        for (int32_t i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
         {
-            int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
-            int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
+            const int8_t *lhs_ptr = &lhs[0];
+            const int8_t *rhs_ptr = &rhs[0];
 
-            vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
+            int32_t res00 = 0;
+            if (bias)
+            {
+                res00 = *bias++;
+            }
 
-            int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
-            int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
-            ker_0 = SXTB16(ker_0);
+            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
+            {
+                int32_t rhs_value0 = (int8_t)rhs_ptr[0] + rhs_offset;
+                int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;
 
-            acc_0 = SMLAD(ker_1, vec_1, acc_0);
-            acc_0 = SMLAD(ker_0, vec_0, acc_0);
+                res00 += lhs_value * rhs_value0;
 
-            ker_0 = arm_nn_read_s8x4_ia(&rhs_1);
-            ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
-            ker_0 = SXTB16(ker_0);
+                ++rhs_ptr;
+                ++lhs_ptr;
+            }
 
-            acc_1 = SMLAD(ker_1, vec_1, acc_1);
-            acc_1 = SMLAD(ker_0, vec_0, acc_1);
-        }
+            // Quantize down
+            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
 
-        for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
-        {
-            const int32_t lhs_temp = (*lhs_vec + lhs_offset);
-            lhs_vec++;
-            acc_0 += lhs_temp * (*rhs_0);
-            rhs_0++;
-            acc_1 += lhs_temp * (*rhs_1);
-            rhs_1++;
-        }
+            // Add offset
+            res00 += dst_offset;
 
-        acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
-        acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);
-
-        // Add offset
-        acc_0 += dst_offset;
-        acc_1 += dst_offset;
-        // Clamp the result
-        acc_0 = MAX(acc_0, activation_min);
-        acc_0 = MIN(acc_0, activation_max);
-        acc_1 = MAX(acc_1, activation_min);
-        acc_1 = MIN(acc_1, activation_max);
-        *dst = (int8_t)acc_0;
-        *(dst + address_offset) = (int8_t)acc_1;
-        dst += 2 * address_offset;
+            // Clamp the result
+            res00 = MAX(res00, activation_min);
+            res00 = MIN(res00, activation_max);
+
+            *dst = (int8_t)res00;
+            dst += address_offset;
+            rhs += rhs_cols;
+        }
+#endif
     }
 
-    if (rhs_rows & 0x1)
+    else
     {
-        int32_t acc_0 = 0;
-        if (bias)
-        {
-            acc_0 = *bias++;
-        }
-        const int32_t col_loop_cnt = rhs_cols / 4;
 
-        const int8_t *lhs_vec = lhs;
-        const int8_t *rhs_0 = rhs;
+#if defined(ARM_MATH_MVEI)
+        const int32_t row_loop_cnt = rhs_rows / 4;
+        const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};
 
-        for (int i = col_loop_cnt; i != 0; i--)
+        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
         {
-            int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
-            int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
-            vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
+            int32_t acc_0 = 0;
+            int32_t acc_1 = 0;
+            int32_t acc_2 = 0;
+            int32_t acc_3 = 0;
+
+            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
 
-            int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
-            int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
-            ker_0 = SXTB16(ker_0);
+            const int8_t *lhs_vec = lhs;
+            const int8_t *rhs_0_ptr = rhs;
+            const int8_t *rhs_1_ptr = rhs + rhs_cols;
+            const int8_t *rhs_2_ptr = rhs + 2 * rhs_cols;
+            const int8_t *rhs_3_ptr = rhs + 3 * rhs_cols;
 
-            acc_0 = SMLAD(ker_1, vec_1, acc_0);
-            acc_0 = SMLAD(ker_0, vec_0, acc_0);
+            if (bias)
+            {
+                acc_0 = *bias++;
+                acc_1 = *bias++;
+                acc_2 = *bias++;
+                acc_3 = *bias++;
+            }
+
+            uint32_t col_cnt = (uint32_t)rhs_cols;
+
+            for (int32_t i = 0; i < col_loop_cnt; i++)
+            {
+                mve_pred16_t p = vctp8q(col_cnt);
+                col_cnt -= 16;
+
+                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
+
+                const int8x16_t ker_0 = vldrbq_z_s8(rhs_0_ptr, p);
+                acc_0 = vmladavaq_s8(acc_0, ker_0, input);
+
+                const int8x16_t ker_1 = vldrbq_z_s8(rhs_1_ptr, p);
+                acc_1 = vmladavaq_s8(acc_1, ker_1, input);
+
+                const int8x16_t ker_2 = vldrbq_z_s8(rhs_2_ptr, p);
+                acc_2 = vmladavaq_s8(acc_2, ker_2, input);
+
+                const int8x16_t ker_3 = vldrbq_z_s8(rhs_3_ptr, p);
+                acc_3 = vmladavaq_s8(acc_3, ker_3, input);
+
+                lhs_vec += 16;
+                rhs_0_ptr += 16;
+                rhs_1_ptr += 16;
+                rhs_2_ptr += 16;
+                rhs_3_ptr += 16;
+            }
+            rhs += 4 * rhs_cols;
+
+            int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};
+
+            const int32x4_t rhs_sum = {kernel_sum[0], kernel_sum[1], kernel_sum[2], kernel_sum[3]};
+            acc += vdupq_n_s32(lhs_offset) * rhs_sum;
+            kernel_sum += 4;
+
+            acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
+            acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
+            acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
+            acc = vminq_s32(acc, vdupq_n_s32(activation_max));
+
+            vstrbq_scatter_offset_s32(dst, address_offset_array, acc);
+
+            dst += 4 * address_offset;
         }
 
-        for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
+        const int loop_cnt = rhs_rows % 4;
+        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
         {
-            const int32_t lhs_temp = (*lhs_vec + lhs_offset);
-            lhs_vec++;
-            acc_0 += lhs_temp * (*rhs_0);
-            rhs_0++;
+            int32_t acc_0 = 0;
+            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
+            const int8_t *lhs_vec = lhs;
+            const int8_t *rhs_ptr = rhs;
+            uint32_t col_cnt = (uint32_t)rhs_cols;
+
+            for (int32_t i = 0; i < col_loop_cnt; i++)
+            {
+                mve_pred16_t p = vctp8q(col_cnt);
+                col_cnt -= 16;
+                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
+
+                const int8x16_t ker_0 = vldrbq_z_s8(rhs_ptr, p);
+                acc_0 = vmladavaq_s8(acc_0, ker_0, input);
+
+                lhs_vec += 16;
+                rhs_ptr += 16;
+            }
+            rhs += rhs_cols;
+
+            if (bias)
+            {
+                acc_0 += *bias;
+                bias++;
+            }
+            const int32_t rhs_sum = kernel_sum[i_row_loop_cnt];
+            const int32_t offsets = rhs_sum * lhs_offset;
+            acc_0 += offsets;
+            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
+            acc_0 += dst_offset;
+
+            // Clamp the result
+            acc_0 = MAX(acc_0, activation_min);
+            *dst = MIN(acc_0, activation_max);
+            dst += address_offset;
         }
 
-        acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
-
-        // Add offset
-        acc_0 += dst_offset;
-        // Clamp the result
-        acc_0 = MAX(acc_0, activation_min);
-        acc_0 = MIN(acc_0, activation_max);
-        *dst = (int8_t)acc_0;
-        dst += address_offset;
-    }
-
-#else
-    (void)kernel_sum;
+#elif defined(ARM_MATH_DSP)
+        (void)kernel_sum;
 
-    const int32_t row_loop_cnt = rhs_rows / 3;
+        const int32_t row_loop_cnt = rhs_rows / 2;
+        const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
+        const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
 
-    for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
-    {
-        const int8_t *lhs_ptr = lhs;
-        const int8_t *rhs_ptr_0 = &rhs[0];
-        const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
-        const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];
-
-        int32_t res00 = 0;
-        int32_t res01 = 0;
-        int32_t res02 = 0;
-        if (bias)
+        for (int32_t i = 0; i < row_loop_cnt; i++)
         {
-            res00 = *bias++;
-            res01 = *bias++;
-            res02 = *bias++;
+            int32_t acc_0 = 0;
+            int32_t acc_1 = 0;
+            if (bias)
+            {
+                acc_0 = *bias++;
+                acc_1 = *bias++;
+            }
+
+            const int32_t col_loop_cnt = rhs_cols / 4;
+
+            const int8_t *lhs_vec = lhs;
+            const int8_t *rhs_0_ptr = rhs;
+            const int8_t *rhs_1_ptr = rhs + rhs_cols;
+            rhs += 2 * rhs_cols;
+
+            for (int32_t j = col_loop_cnt; j != 0; j--)
+            {
+                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
+                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
+
+                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
+
+                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0_ptr);
+                int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
+                ker_0 = SXTB16(ker_0);
+
+                acc_0 = SMLAD(ker_1, vec_1, acc_0);
+                acc_0 = SMLAD(ker_0, vec_0, acc_0);
+
+                ker_0 = arm_nn_read_s8x4_ia(&rhs_1_ptr);
+                ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
+                ker_0 = SXTB16(ker_0);
+
+                acc_1 = SMLAD(ker_1, vec_1, acc_1);
+                acc_1 = SMLAD(ker_0, vec_0, acc_1);
+            }
+
+            for (int32_t k = col_loop_cnt * 4; k < rhs_cols; k++)
+            {
+                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
+                lhs_vec++;
+                acc_0 += lhs_temp * (*rhs_0_ptr);
+                rhs_0_ptr++;
+                acc_1 += lhs_temp * (*rhs_1_ptr);
+                rhs_1_ptr++;
+            }
+
+            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
+            acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);
+
+            // Add offset
+            acc_0 += dst_offset;
+            acc_1 += dst_offset;
+            // Clamp the result
+            acc_0 = MAX(acc_0, activation_min);
+            acc_0 = MIN(acc_0, activation_max);
+            acc_1 = MAX(acc_1, activation_min);
+            acc_1 = MIN(acc_1, activation_max);
+            *dst = (int8_t)acc_0;
+            *(dst + address_offset) = (int8_t)acc_1;
+            dst += 2 * address_offset;
         }
-        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
+
+        if (rhs_rows & 0x1)
         {
-            const int32_t rhs_value0 = (int8_t)*rhs_ptr_0;
-            const int32_t rhs_value1 = (int8_t)*rhs_ptr_1;
-            const int32_t rhs_value2 = (int8_t)*rhs_ptr_2;
-            const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;
-
-            res00 += lhs_value * rhs_value0;
-            res01 += lhs_value * rhs_value1;
-            res02 += lhs_value * rhs_value2;
-
-            ++rhs_ptr_0;
-            ++rhs_ptr_1;
-            ++rhs_ptr_2;
-            ++lhs_ptr;
+            int32_t acc_0 = 0;
+            if (bias)
+            {
+                acc_0 = *bias++;
+            }
+            const int32_t col_loop_cnt = rhs_cols / 4;
+
+            const int8_t *lhs_vec = lhs;
+            const int8_t *rhs_ptr = rhs;
+
+            for (int32_t i = col_loop_cnt; i != 0; i--)
+            {
+                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
+                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
+                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
+
+                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_ptr);
+                int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
+                ker_0 = SXTB16(ker_0);
+
+                acc_0 = SMLAD(ker_1, vec_1, acc_0);
+                acc_0 = SMLAD(ker_0, vec_0, acc_0);
+            }
+
+            for (int32_t j = col_loop_cnt * 4; j < rhs_cols; j++)
+            {
+                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
+                lhs_vec++;
+                acc_0 += lhs_temp * (*rhs_ptr);
+                rhs_ptr++;
+            }
+
+            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
+
+            // Add offset
+            acc_0 += dst_offset;
+            // Clamp the result
+            acc_0 = MAX(acc_0, activation_min);
+            acc_0 = MIN(acc_0, activation_max);
+            *dst = (int8_t)acc_0;
+            dst += address_offset;
         }
-        // Quantize down
-        res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
-        res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
-        res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);
-
-        // Add offset
-        res00 += dst_offset;
-        res01 += dst_offset;
-        res02 += dst_offset;
-
-        // Clamp the result
-        res00 = MAX(res00, activation_min);
-        res00 = MIN(res00, activation_max);
-        res01 = MAX(res01, activation_min);
-        res01 = MIN(res01, activation_max);
-        res02 = MAX(res02, activation_min);
-        res02 = MIN(res02, activation_max);
-
-        *dst = (int8_t)res00;
-        *(dst + address_offset) = (int8_t)res01;
-        *(dst + 2 * address_offset) = (int8_t)res02;
-        dst += 3 * address_offset;
-
-        rhs += 3 * rhs_cols;
-    }
 
-    const int loop_cnt = rhs_rows % 3;
+#else
+        (void)kernel_sum;
 
-    for (int i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
-    {
-        const int8_t *lhs_ptr = &lhs[0];
-        const int8_t *rhs_ptr = &rhs[0];
+        const int32_t row_loop_cnt = rhs_rows / 3;
 
-        int32_t res00 = 0;
-        if (bias)
+        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
         {
-            res00 = *bias++;
+            const int8_t *lhs_ptr = lhs;
+            const int8_t *rhs_ptr_0 = &rhs[0];
+            const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
+            const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];
+
+            int32_t res00 = 0;
+            int32_t res01 = 0;
+            int32_t res02 = 0;
+            if (bias)
+            {
+                res00 = *bias++;
+                res01 = *bias++;
+                res02 = *bias++;
+            }
+            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
+            {
+                const int32_t rhs_value0 = (int8_t)*rhs_ptr_0;
+                const int32_t rhs_value1 = (int8_t)*rhs_ptr_1;
+                const int32_t rhs_value2 = (int8_t)*rhs_ptr_2;
+                const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;
+
+                res00 += lhs_value * rhs_value0;
+                res01 += lhs_value * rhs_value1;
+                res02 += lhs_value * rhs_value2;
+
+                ++rhs_ptr_0;
+                ++rhs_ptr_1;
+                ++rhs_ptr_2;
+                ++lhs_ptr;
+            }
+            // Quantize down
+            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
+            res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
+            res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);
+
+            // Add offset
+            res00 += dst_offset;
+            res01 += dst_offset;
+            res02 += dst_offset;
+
+            // Clamp the result
+            res00 = MAX(res00, activation_min);
+            res00 = MIN(res00, activation_max);
+            res01 = MAX(res01, activation_min);
+            res01 = MIN(res01, activation_max);
+            res02 = MAX(res02, activation_min);
+            res02 = MIN(res02, activation_max);
+
+            *dst = (int8_t)res00;
+            *(dst + address_offset) = (int8_t)res01;
+            *(dst + 2 * address_offset) = (int8_t)res02;
+            dst += 3 * address_offset;
+
+            rhs += 3 * rhs_cols;
         }
 
-        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
+        const int loop_cnt = rhs_rows % 3;
+
+        for (int32_t i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
         {
-            int32_t rhs_value0 = (int8_t)rhs_ptr[0];
-            int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;
+            const int8_t *lhs_ptr = &lhs[0];
+            const int8_t *rhs_ptr = &rhs[0];
 
-            res00 += lhs_value * rhs_value0;
+            int32_t res00 = 0;
+            if (bias)
+            {
+                res00 = *bias++;
+            }
 
-            ++rhs_ptr;
-            ++lhs_ptr;
-        }
+            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
+            {
+                int32_t rhs_value0 = (int8_t)rhs_ptr[0];
+                int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;
 
-        // Quantize down
-        res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
+                res00 += lhs_value * rhs_value0;
 
-        // Add offset
-        res00 += dst_offset;
+                ++rhs_ptr;
+                ++lhs_ptr;
+            }
 
-        // Clamp the result
-        res00 = MAX(res00, activation_min);
-        res00 = MIN(res00, activation_max);
+            // Quantize down
+            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
 
-        *dst = (int8_t)res00;
-        dst += address_offset;
-        rhs += rhs_cols;
-    }
+            // Add offset
+            res00 += dst_offset;
+
+            // Clamp the result
+            res00 = MAX(res00, activation_min);
+            res00 = MIN(res00, activation_max);
+
+            *dst = (int8_t)res00;
+            dst += address_offset;
+            rhs += rhs_cols;
+        }
 #endif
+    }
     return ARM_CMSIS_NN_SUCCESS;
 }
 

+ 5 - 4
Source/SVDFunctions/arm_svdf_s8.c

@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+ * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -21,8 +21,8 @@
  * Title:        arm_svdf_s8.c
  * Description:  S8 basic SVDF layer function
  *
- * $Date:        5 September 2023
- * $Revision:    V.6.0.0
+ * $Date:        14 Feb 2024
+ * $Revision:    V.6.1.0
  *
  * Target :  Arm(R) M-Profile Architecture
  *
@@ -133,7 +133,8 @@ arm_cmsis_nn_status arm_svdf_s8(const cmsis_nn_context *ctx,
                                                            feature_batches,
                                                            in_activation_min,
                                                            in_activation_max,
-                                                           time_batches);
+                                                           time_batches,
+                                                           0);
 
         if (res != ARM_CMSIS_NN_SUCCESS)
         {

+ 1 - 1
Tests/UnitTest/TestCases/Common/fc_s4_weights_template.json → Tests/UnitTest/TestCases/Common/fc_weights_template.json

@@ -35,7 +35,7 @@
             output_size,
             input_size
           ],
-          "type": "INT4",
+          "type": "w_type",
           "buffer": 1,
           "name" : "tensor_weight",
           "quantization": {

+ 1 - 1
Tests/UnitTest/TestCases/Common/fc_s4_weights_template_null_bias.json → Tests/UnitTest/TestCases/Common/fc_weights_template_null_bias.json

@@ -35,7 +35,7 @@
             output_size,
             input_size
           ],
-          "type": "INT4",
+          "type": "w_type",
           "buffer": 1,
           "name" : "tensor_weight",
           "quantization": {

+ 0 - 0
Tests/UnitTest/TestCases/Common/fc_s4_weights_template_null_bias_unpacked.json → Tests/UnitTest/TestCases/Common/fc_weights_template_null_bias_unpacked.json


+ 0 - 0
Tests/UnitTest/TestCases/Common/fc_s4_weights_template_unpacked.json → Tests/UnitTest/TestCases/Common/fc_weights_template_unpacked.json


+ 6 - 0
Tests/UnitTest/TestCases/TestData/fully_connected_w_zp/biases_data.h

@@ -0,0 +1,6 @@
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-2-g0b15fdfcb3f.
+#pragma once
+#include <stdint.h>
+
+const int32_t *fully_connected_w_zp_biases = NULL;

+ 18 - 0
Tests/UnitTest/TestCases/TestData/fully_connected_w_zp/config_data.h

@@ -0,0 +1,18 @@
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-2-g0b15fdfcb3f.
+#pragma once
+#define FULLY_CONNECTED_W_ZP_OUT_CH 6
+#define FULLY_CONNECTED_W_ZP_IN_CH 10
+#define FULLY_CONNECTED_W_ZP_INPUT_W 2
+#define FULLY_CONNECTED_W_ZP_INPUT_H 1
+#define FULLY_CONNECTED_W_ZP_DST_SIZE 18
+#define FULLY_CONNECTED_W_ZP_INPUT_SIZE 20
+#define FULLY_CONNECTED_W_ZP_OUT_ACTIVATION_MIN -128
+#define FULLY_CONNECTED_W_ZP_OUT_ACTIVATION_MAX 127
+#define FULLY_CONNECTED_W_ZP_INPUT_BATCHES 3
+#define FULLY_CONNECTED_W_ZP_OUTPUT_MULTIPLIER 1417628845
+#define FULLY_CONNECTED_W_ZP_OUTPUT_SHIFT -7
+#define FULLY_CONNECTED_W_ZP_ACCUMULATION_DEPTH 20
+#define FULLY_CONNECTED_W_ZP_INPUT_OFFSET -2
+#define FULLY_CONNECTED_W_ZP_FILTER_OFFSET -15
+#define FULLY_CONNECTED_W_ZP_OUTPUT_OFFSET 35

+ 9 - 0
Tests/UnitTest/TestCases/TestData/fully_connected_w_zp/input_data.h

@@ -0,0 +1,9 @@
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-2-g0b15fdfcb3f.
+#pragma once
+#include <stdint.h>
+
+const int8_t fully_connected_w_zp_input[60] = {
+    24,   -17, 19,  -78,  -113, 35, -125, -40,  77,  -59,  -46, -56, -128, 25,  59,   79, 122, 59,  -46, -37,
+    37,   -10, -56, -100, -26,  -9, -52,  -128, -55, -122, 24,  4,   65,   31,  124,  87, -55, 96,  120, 35,
+    -104, -18, 4,   -90,  35,   82, -111, -111, -31, -117, 20,  -84, -29,  -45, -118, 86, -47, -50, -69, -35};

+ 7 - 0
Tests/UnitTest/TestCases/TestData/fully_connected_w_zp/output_ref_data.h

@@ -0,0 +1,7 @@
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-2-g0b15fdfcb3f.
+#pragma once
+#include <stdint.h>
+
+const int8_t fully_connected_w_zp_output_ref[18] =
+    {-3, 127, -87, 86, 127, -21, 86, 0, -74, 94, 127, -92, 127, 36, 127, 127, 127, 34};

+ 7 - 0
Tests/UnitTest/TestCases/TestData/fully_connected_w_zp/test_data.h

@@ -0,0 +1,7 @@
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-2-g0b15fdfcb3f.
+#include "biases_data.h"
+#include "config_data.h"
+#include "input_data.h"
+#include "output_ref_data.h"
+#include "weights_data.h"

+ 12 - 0
Tests/UnitTest/TestCases/TestData/fully_connected_w_zp/weights_data.h

@@ -0,0 +1,12 @@
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-2-g0b15fdfcb3f.
+#pragma once
+#include <stdint.h>
+
+const int8_t fully_connected_w_zp_weights[120] = {
+    -44,  -43,  22,   -123, -26, -126, -6,   -8,  -94, -46, -15,  89,   76,   -47,  -114, 28,   49,   -54, 4,    8,
+    124,  -96,  81,   46,   -99, 95,   -107, -58, 48,  116, 32,   -32,  -128, -84,  58,   -45,  39,   -40, 111,  -56,
+    -92,  -128, 57,   -33,  -1,  15,   38,   89,  109, 37,  -99,  123,  64,   -110, -101, 64,   -116, -19, 91,   -89,
+    -102, -31,  -101, -76,  -27, 68,   -112, 41,  49,  -42, 30,   -122, 109,  -89,  31,   -52,  -127, 9,   -120, -17,
+    64,   45,   -2,   51,   -97, -29,  -128, -93, 55,  -77, -11,  34,   -16,  0,    -78,  -81,  -115, 96,  64,   -96,
+    -50,  52,   -38,  116,  98,  102,  85,   -86, 106, 54,  -122, -83,  -22,  65,   -23,  -113, -19,  80,  41,   -51};

+ 3 - 1
Tests/UnitTest/TestCases/test_arm_fully_connected_s8/Unity/unity_test_arm_fully_connected_s8.c

@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright 2010-2021, 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -46,6 +46,8 @@ void tearDown(void) {}
 
 void test_fully_connected_arm_fully_connected_s8(void) { fully_connected_arm_fully_connected_s8(); }
 
+void test_fully_connected_w_zp_arm_fully_connected_s8(void) { fully_connected_w_zp_arm_fully_connected_s8(); }
+
 void test_fully_connected_mve_0_arm_fully_connected_s8(void) { fully_connected_mve_0_arm_fully_connected_s8(); }
 
 void test_fully_connected_mve_1_arm_fully_connected_s8(void) { fully_connected_mve_1_arm_fully_connected_s8(); }

+ 69 - 0
Tests/UnitTest/TestCases/test_arm_fully_connected_s8/test_arm_fully_connected_s8.c

@@ -25,6 +25,7 @@
 #include "../TestData/fully_connected_mve_1/test_data.h"
 #include "../TestData/fully_connected_null_bias_0/test_data.h"
 #include "../TestData/fully_connected_out_activation/test_data.h"
+#include "../TestData/fully_connected_w_zp/test_data.h"
 #include "../Utils/validate.h"
 
 void fully_connected_arm_fully_connected_s8(void)
@@ -95,6 +96,74 @@ void fully_connected_arm_fully_connected_s8(void)
     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
 }
 
+void fully_connected_w_zp_arm_fully_connected_s8(void)
+{
+    const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
+    int8_t output[FULLY_CONNECTED_W_ZP_DST_SIZE] = {0};
+
+    cmsis_nn_context ctx;
+    cmsis_nn_fc_params fc_params;
+    cmsis_nn_per_tensor_quant_params quant_params;
+    cmsis_nn_dims input_dims;
+    cmsis_nn_dims filter_dims;
+    cmsis_nn_dims bias_dims;
+    cmsis_nn_dims output_dims;
+
+    const int32_t *bias_data = fully_connected_w_zp_biases;
+    const int8_t *kernel_data = fully_connected_w_zp_weights;
+    const int8_t *input_data = fully_connected_w_zp_input;
+    const int8_t *output_ref = fully_connected_w_zp_output_ref;
+    const int32_t output_ref_size = FULLY_CONNECTED_W_ZP_DST_SIZE;
+
+    input_dims.n = FULLY_CONNECTED_W_ZP_INPUT_BATCHES;
+    input_dims.w = FULLY_CONNECTED_W_ZP_INPUT_W;
+    input_dims.h = FULLY_CONNECTED_W_ZP_INPUT_H;
+    input_dims.c = FULLY_CONNECTED_W_ZP_IN_CH;
+    filter_dims.n = FULLY_CONNECTED_W_ZP_ACCUMULATION_DEPTH;
+    filter_dims.c = FULLY_CONNECTED_W_ZP_OUT_CH;
+    output_dims.n = FULLY_CONNECTED_W_ZP_INPUT_BATCHES;
+    output_dims.c = FULLY_CONNECTED_W_ZP_OUT_CH;
+
+    fc_params.input_offset = FULLY_CONNECTED_W_ZP_INPUT_OFFSET;
+    fc_params.filter_offset = FULLY_CONNECTED_W_ZP_FILTER_OFFSET;
+    fc_params.output_offset = FULLY_CONNECTED_W_ZP_OUTPUT_OFFSET;
+    fc_params.activation.min = FULLY_CONNECTED_W_ZP_OUT_ACTIVATION_MIN;
+    fc_params.activation.max = FULLY_CONNECTED_W_ZP_OUT_ACTIVATION_MAX;
+
+    quant_params.multiplier = FULLY_CONNECTED_W_ZP_OUTPUT_MULTIPLIER;
+    quant_params.shift = FULLY_CONNECTED_W_ZP_OUTPUT_SHIFT;
+
+    const int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
+    ctx.buf = malloc(buf_size);
+    ctx.size = buf_size;
+
+#if defined(ARM_MATH_MVEI)
+    int32_t *buf = ctx.buf;
+    TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data, 1, NULL));
+#endif
+
+    arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
+                                                        &fc_params,
+                                                        &quant_params,
+                                                        &input_dims,
+                                                        input_data,
+                                                        &filter_dims,
+                                                        kernel_data,
+                                                        &bias_dims,
+                                                        bias_data,
+                                                        &output_dims,
+                                                        output);
+
+    if (ctx.buf)
+    {
+        // The caller is responsible to clear the scratch buffers for security reasons if applicable.
+        memset(ctx.buf, 0, buf_size);
+        free(ctx.buf);
+    }
+    TEST_ASSERT_EQUAL(expected, result);
+    TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
+}
+
 void fully_connected_mve_0_arm_fully_connected_s8(void)
 {
     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;

+ 11 - 13
Tests/UnitTest/conv_settings.py

@@ -93,18 +93,15 @@ class ConvSettings(TestSettings):
                 raise RuntimeError("out channel ({}) is not multiple of in channel ({})".format(out_ch, in_ch))
             if groups != 1:
                 raise RuntimeError("ERROR: Groups cannot be used for depthwise convolution")
+        else:
+            self.channel_multiplier = 0
 
         self.filter_ch = in_ch // groups
         if in_ch % groups != 0:
-            print(in_ch)
-            print(groups)
             raise RuntimeError("ERROR: Number of input channels must be an even multiple of groups")
         if out_ch % groups != 0:
             raise RuntimeError("ERROR: Number of output channels must be an even multiple of groups")
 
-        else:
-            self.channel_multiplier = 0
-
         if self.int4_weights:
             if self.test_type == 'conv':
                 self.json_template = "TestCases/Common/conv2d_s4_weights_template.json"
@@ -149,7 +146,6 @@ class ConvSettings(TestSettings):
 
         return per_channel_multiplier, per_channel_shift
 
-    # TODO
     def quantize_float_data(self, data=None, quantization_bit_range=8, quantization_type="affine", tf_tensor=False):
         if data is not None:
             if tf_tensor:
@@ -162,13 +158,13 @@ class ConvSettings(TestSettings):
                 data_max = max(data_max, 0.0)
 
                 scale = (data_max - data_min) / (pow(2, quantization_bit_range) - 1)
-                zero_point = -(round(data_max * scale)) - pow(2, quantization_bit_range-1)
-                zero_point = max(zero_point, pow(quantization_bit_range-1) - 1)
-                zero_point = min(zero_point, -pow(quantization_bit_range-1))
+                zero_point = -(round(data_max * scale)) - pow(2, quantization_bit_range - 1)
+                zero_point = max(zero_point, pow(quantization_bit_range - 1) - 1)
+                zero_point = min(zero_point, -pow(quantization_bit_range - 1))
 
             elif quantization_type.lower() == "symmetric":
                 absolute_max = max(abs(data_min), abs(data_max))
-                scale = absolute_max / (pow(2, quantization_bit_range-1) - 1)
+                scale = absolute_max / (pow(2, quantization_bit_range - 1) - 1)
                 zero_point = 0
 
             else:
@@ -283,7 +279,8 @@ class ConvSettings(TestSettings):
                 generated_json = self.generate_json_from_template(
                     None, weights, int8_time_weights=True, bias_data=biases, bias_buffer=3)
             else:
-                generated_json = self.generate_json_from_template(weights, int8_time_weights=False, bias_data=quant_bias, bias_buffer=2)
+                generated_json = self.generate_json_from_template(weights, int8_time_weights=False,
+                                                                  bias_data=quant_bias, bias_buffer=2)
 
             self.flatc_generate_tflite(generated_json, self.schema_file)
 
@@ -317,7 +314,7 @@ class ConvSettings(TestSettings):
                                                     padding=self.padding,
                                                     input_shape=input_shape[1:],
                                                     dilation_rate=(self.dilation_y, self.dilation_x),
-                                                groups=self.groups)
+                                                    groups=self.groups)
                 model.add(conv_layer)
                 conv_layer.set_weights([weights, biases])
             elif self.test_type == 'depthwise_conv':
@@ -335,7 +332,8 @@ class ConvSettings(TestSettings):
                                                                         strides=(self.stride_y, self.stride_x),
                                                                         padding=self.padding,
                                                                         input_shape=input_shape[1:],
-                                                                        dilation_rate=(self.dilation_y, self.dilation_x),
+                                                                        dilation_rate=(self.dilation_y,
+                                                                                       self.dilation_x),
                                                                         use_bias=self.generate_bias)
                 model.add(transposed_conv_layer)
                 if self.generate_bias:

+ 38 - 5
Tests/UnitTest/fully_connected_settings.py

@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -85,17 +85,22 @@ class FullyConnectedSettings(TestSettings):
                          interpreter=interpreter,
                          int4_weights=int4_weights)
 
-        if self.int4_weights:
+        self.filter_zero_point = w_zp
+
+        if self.int4_weights or self.filter_zero_point:
             if self.generate_bias:
-                self.json_template = "TestCases/Common/fc_s4_weights_template.json"
+                self.json_template = "TestCases/Common/fc_weights_template.json"
             else:
-                self.json_template = "TestCases/Common/fc_s4_weights_template_null_bias.json"
+                self.json_template = "TestCases/Common/fc_weights_template_null_bias.json"
+
+            weight_type = "INT4" if self.int4_weights else "INT8"
 
             self.json_replacements = {
                 "batches": batches,
                 "input_size": in_ch * x_in * y_in,
                 "input_scale": input_scale,
                 "input_zp": input_zp,
+                "w_type": weight_type,
                 "w_scale": w_scale,
                 "w_zp": w_zp,
                 "bias_size": out_ch,
@@ -118,6 +123,7 @@ class FullyConnectedSettings(TestSettings):
             f.write("#define {}_OUTPUT_SHIFT {}\n".format(prefix, self.quantized_shift))
             f.write("#define {}_ACCUMULATION_DEPTH {}\n".format(prefix, self.input_ch * self.x_input * self.y_input))
             f.write("#define {}_INPUT_OFFSET {}\n".format(prefix, -self.input_zero_point))
+            f.write("#define {}_FILTER_OFFSET {}\n".format(prefix, -self.filter_zero_point))
             f.write("#define {}_OUTPUT_OFFSET {}\n".format(prefix, self.output_zero_point))
 
     def quantize_multiplier(self, weights_scale):
@@ -151,7 +157,30 @@ class FullyConnectedSettings(TestSettings):
         else:
             biases = None
 
-        if self.int4_weights:
+        if self.filter_zero_point:
+            temp1 = self.model_path
+            temp2 = self.json_template
+
+            fc_weights_format = [self.input_ch * self.y_input * self.x_input * self.output_ch]
+            if weights is not None:
+                weights = tf.reshape(weights, fc_weights_format)
+            else:
+                weights = self.get_randomized_data(fc_weights_format,
+                                                   self.kernel_table_file,
+                                                   minrange=TestSettings.INT8_MIN,
+                                                   maxrange=TestSettings.INT8_MAX,
+                                                   regenerate=self.regenerate_new_weights)
+
+            self.model_path = self.model_path
+            self.json_template = self.json_template
+            generated_json = self.generate_json_from_template(weights, bias_data=biases, bias_buffer=2)
+            self.flatc_generate_tflite(generated_json, self.schema_file)
+
+            weights_size = weights.numpy().size
+            filter_index = 1
+            bias_index = 2
+
+        elif self.int4_weights:
             # Generate weights, both packed and unpacked model from JSON
             temp1 = self.model_path
             temp2 = self.json_template
@@ -226,6 +255,10 @@ class FullyConnectedSettings(TestSettings):
            (self.generate_bias and biases.numpy().size != interpreter.get_tensor(bias_layer['index']).size):
             raise RuntimeError(f"Dimension mismatch for {self.testdataset}")
 
+        weights_zero_point = filter_layer['quantization_parameters']['zero_points'][0]
+        if weights_zero_point != self.filter_zero_point:
+            raise RuntimeError(f"Filter zero point mismatch for {self.filter_zero_point}")
+
         self.x_output = 1
         self.y_output = 1
 

+ 21 - 0
Tests/UnitTest/generate_test_data.py

@@ -2028,6 +2028,27 @@ def load_testdata_sets(regenerate_input, regenerate_weights, regenerate_biases,
                                                     y_in=1,
                                                     batches=3,
                                                     interpreter=interpreter)
+    dataset = 'fully_connected_w_zp'
+    testdata_sets[dataset] = FullyConnectedSettings(dataset,
+                                                    type_of_test,
+                                                    regenerate_weights,
+                                                    regenerate_input,
+                                                    regenerate_biases,
+                                                    schema_file,
+                                                    in_ch=10,
+                                                    out_ch=6,
+                                                    x_in=2,
+                                                    y_in=1,
+                                                    batches=3,
+                                                    input_scale=0.034,
+                                                    w_scale=0.054,
+                                                    bias_scale=0.00000001,
+                                                    output_scale=0.356,
+                                                    input_zp=2,
+                                                    output_zp=35,
+                                                    w_zp=15,
+                                                    generate_bias=False,
+                                                    interpreter=interpreter)
     dataset = 'fully_connected_mve_0'
     testdata_sets[dataset] = FullyConnectedSettings(dataset,
                                                     type_of_test,