Explorar el Código

CMSIS-NN: Adding batch to conv kernel

Fredrik Knutsson hace 6 años
padre
commit
2d42573d5b

+ 34 - 30
CMSIS/NN/Include/arm_nnfunctions.h

@@ -123,20 +123,21 @@ extern    "C"
 
 /**
    * @brief Basic s8 convolution function
-   * @param[in]       input      pointer to input tensor. Range: int8, format: [H,W,in_ch]
-   * @param[in]       input_x    input tensor width
-   * @param[in]       input_y    input tensor height
-   * @param[in]       input_ch   number of input tensor channels
-   * @param[in]       kernel     pointer to kernel weights. Range: int8, format: [out_ch, H, W, in_ch]
-   * @param[in]       output_ch  number of filters, i.e., output tensor channels
-   * @param[in]       kernel_x   filter/kernel width
-   * @param[in]       kernel_y   filter/kernel height
-   * @param[in]       pad_x      padding along width
-   * @param[in]       pad_y      padding along height
-   * @param[in]       stride_x   convolution stride x
-   * @param[in]       stride_y   convolution stride y
-   * @param[in]       bias       pointer to per output channel bias. Range: int32
-   * @param[in,out]   output     pointer to output tensor. format: [H, W, out_ch]
+   * @param[in]       input           pointer to input tensor. Range: int8, format: [N,H,W,in_ch]
+   * @param[in]       input_x         input tensor width
+   * @param[in]       input_y         input tensor height
+   * @param[in]       input_ch        number of input tensor channels
+   * @param[in]       input_batches   number of input batches
+   * @param[in]       kernel          pointer to kernel weights. Range: int8, format: [out_ch, H, W, in_ch]
+   * @param[in]       output_ch       number of filters, i.e., output tensor channels
+   * @param[in]       kernel_x        filter/kernel width
+   * @param[in]       kernel_y        filter/kernel height
+   * @param[in]       pad_x           padding along width
+   * @param[in]       pad_y           padding along height
+   * @param[in]       stride_x        convolution stride x
+   * @param[in]       stride_y        convolution stride y
+   * @param[in]       bias            pointer to per output channel bias. Range: int32
+   * @param[in,out]   output          pointer to output tensor. format: [N, H, W, out_ch]
    * @param[in]       output_shift    pointer to per output channel requantization shift parameter.
    * @param[in]       output_mult     pointer to per output channel requantization multiplier parameter.
    * @param[in]       out_offset      output tensor offset. Range: int8
@@ -161,6 +162,7 @@ extern    "C"
                                const uint16_t input_x,
                                const uint16_t input_y,
                                const uint16_t input_ch,
+                               const uint16_t input_batches,
                                const q7_t *kernel,
                                const uint16_t output_ch,
                                const uint16_t kernel_x,
@@ -460,22 +462,23 @@ extern    "C"
 
   /**
    * @brief Fast s8 version for 1x1 convolution (non-square shape)
-   * @param[in]      input         pointer to input tensor.  Format: [H, W, in_ch]
-   * @param[in]      input_x       input tensor dimension x
-   * @param[in]      input_y       input tensor dimension y
-   * @param[in]      input_ch      number of input tensor channels
-   * @param[in]      kernel        pointer to kernel weights. Format: [out_ch, H, W, in_ch]
-   * @param[in]      output_ch     number of filters, i.e., output tensor channels
-   * @param[in]      pad_x         padding size x
-   * @param[in]      pad_y         padding size y
-   * @param[in]      stride_x      convolution stride x
-   * @param[in]      stride_y      convolution stride y
-   * @param[in]      bias          pointer to per channel bias. Range : int32
-   * @param[in,out]  output        pointer to output tensor.  Format: [H, W, out_ch]
-   * @param[in]      output_shift  pointer to per output channel requantization shift parameter.
-   * @param[in]      output_mult   pointer to per output channel requantization multiplier parameter.
-   * @param[in]      out_offset    output tensor offset. Range: int8
-   * @param[in]      input_offset input tensor offset. Range: int8
+   * @param[in]      input                pointer to input tensor.  Format: [N, H, W, in_ch]
+   * @param[in]      input_x              input tensor dimension x
+   * @param[in]      input_y              input tensor dimension y
+   * @param[in]      input_ch             number of input tensor channels
+   * @param[in]      input_batches        number of input batches
+   * @param[in]      kernel               pointer to kernel weights. Format: [out_ch, H, W, in_ch]
+   * @param[in]      output_ch            number of filters, i.e., output tensor channels
+   * @param[in]      pad_x                padding size x
+   * @param[in]      pad_y                padding size y
+   * @param[in]      stride_x             convolution stride x
+   * @param[in]      stride_y             convolution stride y
+   * @param[in]      bias                 pointer to per channel bias. Range : int32
+   * @param[in,out]  output               pointer to output tensor.  Format: [N, H, W, out_ch]
+   * @param[in]      output_shift         pointer to per output channel requantization shift parameter.
+   * @param[in]      output_mult          pointer to per output channel requantization multiplier parameter.
+   * @param[in]      out_offset           output tensor offset. Range: int8
+   * @param[in]      input_offset         input tensor offset. Range: int8
    * @param[in]      out_activation_min   Minimum value to clamp the output to. Range: int8
   * @param[in]      out_activation_max   Maximum value to clamp the output to. Range: int8
    * @param[in]      output_x  output tensor width
@@ -500,6 +503,7 @@ extern    "C"
                                         const uint16_t input_x,
                                         const uint16_t input_y,
                                         const uint16_t input_ch,
+                                        const uint16_t input_batches,
                                         const q7_t *kernel,
                                         const uint16_t output_ch,
                                         const uint16_t pad_x,

+ 88 - 82
CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_1x1_s8_fast.c

@@ -55,6 +55,7 @@ arm_status arm_convolve_1x1_s8_fast(const q7_t *input,
                                     const uint16_t input_x,
                                     const uint16_t input_y,
                                     const uint16_t input_ch,
+                                    const uint16_t input_batches,
                                     const q7_t *kernel,
                                     const uint16_t output_ch,
                                     const uint16_t pad_x,
@@ -79,102 +80,114 @@ arm_status arm_convolve_1x1_s8_fast(const q7_t *input,
         return ARM_MATH_SIZE_MISMATCH;
     }
 
+    int i_batch;
+    for (i_batch = 0; i_batch < input_batches; i_batch++)
+    {
+        /* Step the input/output pointers by exactly one batch per
+           iteration. The pointers are advanced in place, so the step must
+           be a constant batch size: the form "+= i_batch * size" would
+           accumulate offsets of 0, 1, 3, 6, ... batches and read/write the
+           wrong memory whenever input_batches >= 3. */
+        if (i_batch > 0)
+        {
+            input += (input_x * input_y * input_ch);
+            output += (output_x * output_y * output_ch);
+        }
 #if defined(ARM_MATH_LOOPUNROLL) && defined(ARM_MATH_DSP)
-    /* Optimized version for M cores with DSP extension */
-    int16_t i_out_y, i_out_x;
-    int16_t i_ch_out;
-    (void)input_y;
+        /* Optimized version for M cores with DSP extension */
+        int16_t i_out_y, i_out_x;
+        int16_t i_ch_out;
+        (void)input_y;
 
-    /* Partial(two columns) im2col buffer */
-    q15_t *two_column_buffer = buffer_a;
-    q7_t *out = output;
+        /* Partial(two columns) im2col buffer */
+        q15_t *two_column_buffer = buffer_a;
+        q7_t *out = output;
 
-    for (i_out_y = 0; i_out_y < output_y; i_out_y++)
-    {
-        for (i_out_x = 0; i_out_x < output_x; i_out_x++)
+        for (i_out_y = 0; i_out_y < output_y; i_out_y++)
         {
-            /* Fill buffer for partial im2col */
-            arm_q7_to_q15_reordered_with_offset(input + (i_out_y * input_x + i_out_x) * input_ch,
-                                                two_column_buffer,
-                                                input_ch,
-                                                (q7_t)input_offset);
-            two_column_buffer += input_ch;
-
-            if (two_column_buffer == buffer_a + 2 * input_ch * DIM_KER_X * DIM_KER_Y)
+            for (i_out_x = 0; i_out_x < output_x; i_out_x++)
             {
-                out = arm_nn_mat_mult_kernel_s8_s16_reordered(kernel,
-                                                              buffer_a,
-                                                              output_ch,
-                                                              output_shift,
-                                                              output_mult,
-                                                              (q7_t)out_offset,
-                                                              out_activation_min,
-                                                              out_activation_max,
-                                                              input_ch * DIM_KER_Y * DIM_KER_X,
-                                                              bias, out);
-                /* counter reset */
-                two_column_buffer = buffer_a;
+                /* Fill buffer for partial im2col */
+                arm_q7_to_q15_reordered_with_offset(input + (i_out_y * input_x + i_out_x) * input_ch,
+                                                    two_column_buffer,
+                                                    input_ch,
+                                                    (q7_t)input_offset);
+                two_column_buffer += input_ch;
+
+                if (two_column_buffer == buffer_a + 2 * input_ch * DIM_KER_X * DIM_KER_Y)
+                {
+                    out = arm_nn_mat_mult_kernel_s8_s16_reordered(kernel,
+                                                                buffer_a,
+                                                                output_ch,
+                                                                output_shift,
+                                                                output_mult,
+                                                                (q7_t)out_offset,
+                                                                out_activation_min,
+                                                                out_activation_max,
+                                                                input_ch * DIM_KER_Y * DIM_KER_X,
+                                                                bias, out);
+                    /* counter reset */
+                    two_column_buffer = buffer_a;
+                }
             }
         }
-    }
 
-    /* check if there is an odd column left-over for computation */
-    if (two_column_buffer != buffer_a)
-    {
-        const q7_t *ker_a = kernel;
-        for (i_ch_out = 0; i_ch_out < output_ch; i_ch_out++)
+        /* check if there is an odd column left-over for computation */
+        if (two_column_buffer != buffer_a)
         {
-            q31_t sum = bias[i_ch_out];
-
-            /* Point to the beginning of the im2col buffer where the input is available as a rearranged column */
-            const q15_t *ip_as_col = buffer_a;
-            uint16_t col_count = (input_ch * DIM_KER_X * DIM_KER_Y) >> 2;
-
-            while (col_count)
+            const q7_t *ker_a = kernel;
+            for (i_ch_out = 0; i_ch_out < output_ch; i_ch_out++)
             {
-                q31_t ker_a1, ker_a2;
-                q31_t in_b1, in_b2;
-                ker_a = read_and_pad_reordered(ker_a, &ker_a1, &ker_a2);
-
-                in_b1 = arm_nn_read_q15x2_ia(&ip_as_col);
-                sum = __SMLAD(ker_a1, in_b1, sum);
-                in_b2 = arm_nn_read_q15x2_ia(&ip_as_col);
-                sum = __SMLAD(ker_a2, in_b2, sum);
-
-                col_count--;
-            }
-            col_count = input_ch * DIM_KER_Y * DIM_KER_X & 0x3;
-            while (col_count)
-            {
-                q7_t ker_a1 = *ker_a++;
-                q15_t in_b1 = *ip_as_col++;
-                sum += ker_a1 * in_b1;
-                col_count--;
+                q31_t sum = bias[i_ch_out];
+
+                /* Point to the beginning of the im2col buffer where the input is available as a rearranged column */
+                const q15_t *ip_as_col = buffer_a;
+                uint16_t col_count = (input_ch * DIM_KER_X * DIM_KER_Y) >> 2;
+
+                while (col_count)
+                {
+                    q31_t ker_a1, ker_a2;
+                    q31_t in_b1, in_b2;
+                    ker_a = read_and_pad_reordered(ker_a, &ker_a1, &ker_a2);
+
+                    in_b1 = arm_nn_read_q15x2_ia(&ip_as_col);
+                    sum = __SMLAD(ker_a1, in_b1, sum);
+                    in_b2 = arm_nn_read_q15x2_ia(&ip_as_col);
+                    sum = __SMLAD(ker_a2, in_b2, sum);
+
+                    col_count--;
+                }
+                col_count = input_ch * DIM_KER_Y * DIM_KER_X & 0x3;
+                while (col_count)
+                {
+                    q7_t ker_a1 = *ker_a++;
+                    q15_t in_b1 = *ip_as_col++;
+                    sum += ker_a1 * in_b1;
+                    col_count--;
+                }
+                sum = arm_nn_requantize(sum, output_mult[i_ch_out], output_shift[i_ch_out]);
+                sum += out_offset;
+                sum = MAX(sum, out_activation_min);
+                sum = MIN(sum, out_activation_max);
+                *out++ = (q7_t)sum;
             }
-            sum = arm_nn_requantize(sum, output_mult[i_ch_out], output_shift[i_ch_out]);
-            sum += out_offset;
-            sum = MAX(sum, out_activation_min);
-            sum = MIN(sum, out_activation_max);
-            *out++ = (q7_t)sum;
         }
-    }
-
 #else
-    /* Run the following code as reference implementation for M cores with no DSP extension or when loop unrolling is
-       not to be done */
-    (void)buffer_a;
-    return arm_convolve_s8(input, input_x, input_y,
-                           input_ch, kernel, output_ch,
-                           DIM_KER_X, DIM_KER_Y,
-                           pad_x, pad_y,
-                           stride_x, stride_y,
-                           bias, output,
-                           output_shift, output_mult,
-                           out_offset, input_offset,
-                           out_activation_min, out_activation_max,
-                           output_x, output_y,
-                           NULL);
+        /* Run the following code as reference implementation for M cores with no DSP extension or when loop unrolling is
+        not to be done */
+        (void)buffer_a;
+        return arm_convolve_s8(input, input_x, input_y,
+                            input_ch, input_batches, kernel, output_ch,
+                            DIM_KER_X, DIM_KER_Y,
+                            pad_x, pad_y,
+                            stride_x, stride_y,
+                            bias, output,
+                            output_shift, output_mult,
+                            out_offset, input_offset,
+                            out_activation_min, out_activation_max,
+                            output_x, output_y,
+                            NULL);
 #endif
+    }
 
     /* Return to application */
     return ARM_MATH_SUCCESS;

+ 117 - 112
CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_s8.c

@@ -52,6 +52,7 @@ arm_status arm_convolve_s8(const q7_t *input,
                            const uint16_t input_x,
                            const uint16_t input_y,
                            const uint16_t input_ch,
+                           const uint16_t input_batches,
                            const q7_t *kernel,
                            const uint16_t output_ch,
                            const uint16_t kernel_x,
@@ -72,156 +73,167 @@ arm_status arm_convolve_s8(const q7_t *input,
                            const uint16_t output_y,
                            q15_t *buffer_a)
 {
-
+    int i_batch;
+    for (i_batch = 0; i_batch < input_batches; i_batch++)
+    {
+        /* Step the input/output pointers by exactly one batch per
+           iteration. The pointers are advanced in place, so the step must
+           be a constant batch size: the form "+= i_batch * size" would
+           accumulate offsets of 0, 1, 3, 6, ... batches and read/write the
+           wrong memory whenever input_batches >= 3. */
+        if (i_batch > 0)
+        {
+            input += (input_x * input_y * input_ch);
+            output += (output_x * output_y * output_ch);
+        }
 #if defined(ARM_MATH_LOOPUNROLL) && defined (ARM_MATH_DSP)
-    int16_t i_out_y, i_out_x, i_ker_y, i_ker_x;
+        int16_t i_out_y, i_out_x, i_ker_y, i_ker_x;
 
-    /* Generate two columns from the input tensor a GEMM computation */
-    q15_t *two_column_buf = buffer_a;
-    q7_t *out = output;
+        /* Generate two columns from the input tensor for a GEMM computation */
+        q15_t *two_column_buf = buffer_a;
+        q7_t *out = output;
 
-    /* This part implements the im2col function */
-    for (i_out_y = 0; i_out_y < output_y; i_out_y++)
-    {
-        for (i_out_x = 0; i_out_x < output_x; i_out_x++)
+        /* This part implements the im2col function */
+        for (i_out_y = 0; i_out_y < output_y; i_out_y++)
         {
-            for (i_ker_y = i_out_y * stride_y - pad_y; i_ker_y < i_out_y * stride_y - pad_y + kernel_y; i_ker_y++)
+            for (i_out_x = 0; i_out_x < output_x; i_out_x++)
             {
-                for (i_ker_x = i_out_x * stride_x - pad_x; i_ker_x < i_out_x * stride_x - pad_x + kernel_x; i_ker_x++)
+                for (i_ker_y = i_out_y * stride_y - pad_y; i_ker_y < i_out_y * stride_y - pad_y + kernel_y; i_ker_y++)
                 {
-                    if (i_ker_y < 0 || i_ker_y >= input_y || i_ker_x < 0 || i_ker_x >= input_x)
+                    for (i_ker_x = i_out_x * stride_x - pad_x; i_ker_x < i_out_x * stride_x - pad_x + kernel_x; i_ker_x++)
                     {
-                        /* Filling 0 for out-of-bound paddings */
-                        memset(two_column_buf, 0, sizeof(q15_t) * input_ch);
-                    }
-                    else
-                    {
-                        /* Copying the pixel data to column */
-                        arm_q7_to_q15_with_offset(input + (i_ker_y * input_x + i_ker_x) * input_ch, two_column_buf, input_ch, input_offset);
+                        if (i_ker_y < 0 || i_ker_y >= input_y || i_ker_x < 0 || i_ker_x >= input_x)
+                        {
+                            /* Filling 0 for out-of-bound paddings */
+                            memset(two_column_buf, 0, sizeof(q15_t) * input_ch);
+                        }
+                        else
+                        {
+                            /* Copying the pixel data to column */
+                            arm_q7_to_q15_with_offset(input + (i_ker_y * input_x + i_ker_x) * input_ch, two_column_buf, input_ch, input_offset);
+                        }
+                        two_column_buf += input_ch;
                     }
-                    two_column_buf += input_ch;
                 }
-            }
 
-            /* Computation is filed for every 2 columns */
-            if (two_column_buf == buffer_a + 2 * input_ch * kernel_y * kernel_x)
-            {
-                out =
-                    arm_nn_mat_mult_kernel_s8_s16(kernel,
-                                                  buffer_a,
-                                                  output_ch,
-                                                  output_shift,
-                                                  output_mult,
-                                                  out_offset,
-                                                  out_activation_min,
-                                                  out_activation_max,
-                                                  input_ch * kernel_y * kernel_x,
-                                                  bias,
-                                                  out);
-
-                /* counter reset */
-                two_column_buf = buffer_a;
+                /* Computation is done for every 2 columns */
+                if (two_column_buf == buffer_a + 2 * input_ch * kernel_y * kernel_x)
+                {
+                    out =
+                        arm_nn_mat_mult_kernel_s8_s16(kernel,
+                                                    buffer_a,
+                                                    output_ch,
+                                                    output_shift,
+                                                    output_mult,
+                                                    out_offset,
+                                                    out_activation_min,
+                                                    out_activation_max,
+                                                    input_ch * kernel_y * kernel_x,
+                                                    bias,
+                                                    out);
+
+                    /* counter reset */
+                    two_column_buf = buffer_a;
+                }
             }
         }
-    }
-
-    /* left-over because odd number of output pixels */
-    if (two_column_buf != buffer_a)
-    {
-        const q7_t *ker_a = kernel;
-        int i;
 
-        for (i = 0; i < output_ch; i++)
+        /* left-over because odd number of output pixels */
+        if (two_column_buf != buffer_a)
         {
-            /* Load the accumulator with bias first */
-            q31_t sum = bias[i];
+            const q7_t *ker_a = kernel;
+            int i;
+
+            for (i = 0; i < output_ch; i++)
+            {
+                /* Load the accumulator with bias first */
+                q31_t sum = bias[i];
 
-            /* Point to the beginning of the im2col buffer where the input is available as a rearranged column */
-            const q15_t *ip_as_col = buffer_a;
+                /* Point to the beginning of the im2col buffer where the input is available as a rearranged column */
+                const q15_t *ip_as_col = buffer_a;
 
-            /* 4 multiply and accumulates are done in one loop. */
-            uint16_t col_count = (input_ch * kernel_y * kernel_x) >> 2;
+                /* 4 multiply and accumulates are done in one loop. */
+                uint16_t col_count = (input_ch * kernel_y * kernel_x) >> 2;
 
-            while (col_count)
-            {
-                q31_t ker_a1, ker_a2;
-                q31_t ip_b1, ip_b2;
+                while (col_count)
+                {
+                    q31_t ker_a1, ker_a2;
+                    q31_t ip_b1, ip_b2;
 
-                ker_a = read_and_pad(ker_a, &ker_a1, &ker_a2);
+                    ker_a = read_and_pad(ker_a, &ker_a1, &ker_a2);
 
-                ip_b1 = arm_nn_read_q15x2_ia(&ip_as_col);
-                sum = __SMLAD(ker_a1, ip_b1, sum);
-                ip_b2 = arm_nn_read_q15x2_ia(&ip_as_col);
-                sum = __SMLAD(ker_a2, ip_b2, sum);
+                    ip_b1 = arm_nn_read_q15x2_ia(&ip_as_col);
+                    sum = __SMLAD(ker_a1, ip_b1, sum);
+                    ip_b2 = arm_nn_read_q15x2_ia(&ip_as_col);
+                    sum = __SMLAD(ker_a2, ip_b2, sum);
 
-                col_count--;
-            }
-            /* Handle left over mac */
-            col_count = input_ch * kernel_y * kernel_x & 0x3;
-            while (col_count)
-            {
-                q7_t ker_a1 = *ker_a++;
-                q15_t ip_b1 = *ip_as_col++;
-                sum += ker_a1 * ip_b1;
-                col_count--;
-            }
+                    col_count--;
+                }
+                /* Handle left over mac */
+                col_count = input_ch * kernel_y * kernel_x & 0x3;
+                while (col_count)
+                {
+                    q7_t ker_a1 = *ker_a++;
+                    q15_t ip_b1 = *ip_as_col++;
+                    sum += ker_a1 * ip_b1;
+                    col_count--;
+                }
 
-            sum = arm_nn_requantize(sum, output_mult[i], output_shift[i]);
-            sum += out_offset;
-            sum = MAX(sum, out_activation_min);
-            sum = MIN(sum, out_activation_max);
-            *out++ = (q7_t)sum;
+                sum = arm_nn_requantize(sum, output_mult[i], output_shift[i]);
+                sum += out_offset;
+                sum = MAX(sum, out_activation_min);
+                sum = MIN(sum, out_activation_max);
+                *out++ = (q7_t)sum;
+            }
         }
-    }
 #else
-    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
-    (void)buffer_a;
-    int32_t i_out_ch, i_out_y, i_out_x, i_input_ch, i_ker_y, i_ker_x;
-    int32_t conv_out;
-    int32_t in_row, in_col;
+        /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
+        (void)buffer_a;
+        int32_t i_out_ch, i_out_y, i_out_x, i_input_ch, i_ker_y, i_ker_x;
+        int32_t conv_out;
+        int32_t in_row, in_col;
 
-    for (i_out_ch = 0; i_out_ch < output_ch; i_out_ch++)
-    {
-        for (i_out_y = 0; i_out_y < output_y; i_out_y++)
+        for (i_out_ch = 0; i_out_ch < output_ch; i_out_ch++)
         {
-            for (i_out_x = 0; i_out_x < output_x; i_out_x++)
+            for (i_out_y = 0; i_out_y < output_y; i_out_y++)
             {
-                conv_out = bias[i_out_ch];
+                for (i_out_x = 0; i_out_x < output_x; i_out_x++)
+                {
+                    conv_out = bias[i_out_ch];
 
-                const int32_t base_idx_y = stride_y * i_out_y - pad_y;
-                const int32_t base_idx_x = stride_x * i_out_x - pad_x;
+                    const int32_t base_idx_y = stride_y * i_out_y - pad_y;
+                    const int32_t base_idx_x = stride_x * i_out_x - pad_x;
 
-                const int32_t ker_y_start = MAX(0, -base_idx_y);
-                const int32_t ker_x_start = MAX(0, -base_idx_x);
+                    const int32_t ker_y_start = MAX(0, -base_idx_y);
+                    const int32_t ker_x_start = MAX(0, -base_idx_x);
 
-                const int32_t ker_y_end = MIN(kernel_y, input_y - base_idx_y);
-                const int32_t ker_x_end = MIN(kernel_x, input_x - base_idx_x);
+                    const int32_t ker_y_end = MIN(kernel_y, input_y - base_idx_y);
+                    const int32_t ker_x_end = MIN(kernel_x, input_x - base_idx_x);
 
-                for (i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
-                {
-                    for (i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
+                    for (i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                     {
-                        const int32_t in_row = base_idx_y + i_ker_y;
-                        const int32_t in_col = base_idx_x + i_ker_x;
-                        for (i_input_ch = 0; i_input_ch < input_ch; i_input_ch++)
+                        for (i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
                         {
-                            conv_out +=
-                                (input[(in_row * input_x + in_col) * input_ch + i_input_ch] + input_offset) *
-                                kernel[i_out_ch * input_ch * kernel_y * kernel_x +
-                                       (i_ker_y * kernel_x + i_ker_x) * input_ch + i_input_ch];
+                            const int32_t in_row = base_idx_y + i_ker_y;
+                            const int32_t in_col = base_idx_x + i_ker_x;
+                            for (i_input_ch = 0; i_input_ch < input_ch; i_input_ch++)
+                            {
+                                conv_out +=
+                                    (input[(in_row * input_x + in_col) * input_ch + i_input_ch] + input_offset) *
+                                    kernel[i_out_ch * input_ch * kernel_y * kernel_x +
+                                        (i_ker_y * kernel_x + i_ker_x) * input_ch + i_input_ch];
+                            }
                         }
                     }
+                    conv_out = arm_nn_requantize(conv_out, output_mult[i_out_ch], output_shift[i_out_ch]);
+                    conv_out += out_offset;
+                    conv_out = MAX(conv_out, out_activation_min);
+                    conv_out = MIN(conv_out, out_activation_max);
+                    output[i_out_ch + (i_out_y * output_x + i_out_x) * output_ch] = (int8_t)conv_out;
                 }
-                conv_out = arm_nn_requantize(conv_out, output_mult[i_out_ch], output_shift[i_out_ch]);
-                conv_out += out_offset;
-                conv_out = MAX(conv_out, out_activation_min);
-                conv_out = MIN(conv_out, out_activation_max);
-                output[i_out_ch + (i_out_y * output_x + i_out_x) * output_ch] = (int8_t)conv_out;
             }
         }
-    }
-
 #endif
+    }
 
     /* Return to application */
     return ARM_MATH_SUCCESS;