Sfoglia il codice sorgente

performance optimizations for arm_mat_mult_fast_q15 and arm_mat_mult_fast_q31

David Palframan 9 anni fa
parent
commit
9bf9a9ec1a

+ 11 - 0
CMSIS/DSP/Include/arm_math.h

@@ -1031,6 +1031,17 @@ extern "C"
                        ((((q31_t)x <<  8) >>  8) & (q31_t)0xFFFF0000)  ));
   }
 
+  /*
+   * @brief C custom defined SMMLA for M3 and M0 processors
+   */
+  CMSIS_INLINE __STATIC_INLINE int32_t __SMMLA(
+  int32_t x,
+  int32_t y,
+  int32_t sum)
+  {
+    return (sum + (int32_t) (((int64_t) x * y) >> 32));
+  }
+
 #endif /* defined (ARM_MATH_CM3) || defined (ARM_MATH_CM0_FAMILY) */
 
 

+ 191 - 23
CMSIS/DSP/Source/MatrixFunctions/arm_mat_mult_fast_q15.c

@@ -97,13 +97,16 @@ arm_status arm_mat_mult_fast_q15(
   uint16_t numColsB = pSrcB->numCols;            /* number of columns of input matrix B */
   uint16_t numColsA = pSrcA->numCols;            /* number of columns of input matrix A */
   uint16_t numRowsB = pSrcB->numRows;            /* number of rows of input matrix A    */
-  uint16_t col, i = 0u, row = numRowsB, colCnt;  /* loop counters */
+  uint32_t col, i = 0u, row = numRowsB, colCnt;  /* loop counters */
   arm_status status;                             /* status of matrix multiplication */
 
 #ifndef UNALIGNED_SUPPORT_DISABLE
 
   q31_t in;                                      /* Temporary variable to hold the input value */
   q31_t inA1, inA2, inB1, inB2;
+  q31_t sum2, sum3, sum4;
+  q15_t *pInA2, *pInB2, *px2;
+  uint32_t j = 0;
 
 #else
 
@@ -269,9 +272,15 @@ arm_status arm_mat_mult_fast_q15(
     i = 0u;
     px = pDst->pData;
 
+#ifndef UNALIGNED_SUPPORT_DISABLE
+    /* Process two rows from matrix A at a time and output two rows at a time */
+    row = row >> 1;
+    px2 = px + numColsB;
+#endif
+
     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
     /* row loop */
-    do
+    while(row > 0u)
     {
       /* For every row wise process, the column loop counter is to be initiated */
       col = numColsB;
@@ -280,18 +289,35 @@ arm_status arm_mat_mult_fast_q15(
        ** to the starting address of the transposed pSrcB data */
       pInB = pSrcBT;
 
+#ifndef UNALIGNED_SUPPORT_DISABLE
+      /* Process two (transposed) columns from matrix B at a time */
+      col = col >> 1;
+      j = 0;
+#endif
+
       /* column loop */
-      do
+      while (col > 0u)
       {
         /* Set the variable sum, that acts as accumulator, to zero */
         sum = 0;
 
-        /* Apply loop unrolling and compute 2 MACs simultaneously. */
-        colCnt = numColsA >> 2;
-
-        /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
+        /* Initiate the pointer pInA to point to the starting address of the column being processed */
         pInA = pSrcA->pData + i;
 
+#ifndef UNALIGNED_SUPPORT_DISABLE
+        sum2 = 0;
+        sum3 = 0;
+        sum4 = 0;
+        pInB  = pSrcBT + j;
+        pInA2 = pInA + numColsA;
+        pInB2 = pInB + numRowsB;
+        
+        /* Read in two elements at once - alows dual MAC instruction */
+        colCnt = numColsA >> 1;
+#else
+        colCnt = numColsA >> 2;
+#endif
+
         /* matrix multiplication */
         while(colCnt > 0u)
         {
@@ -300,29 +326,35 @@ arm_status arm_mat_mult_fast_q15(
 
           inA1 = *__SIMD32(pInA)++;
           inB1 = *__SIMD32(pInB)++;
-          inA2 = *__SIMD32(pInA)++;
-          inB2 = *__SIMD32(pInB)++;
+          inA2 = *__SIMD32(pInA2)++;
+          inB2 = *__SIMD32(pInB2)++;
 
-          sum = __SMLAD(inA1, inB1, sum);
-          sum = __SMLAD(inA2, inB2, sum);
+          sum  = __SMLAD(inA1, inB1, sum);
+          sum2 = __SMLAD(inA1, inB2, sum2);
+          sum3 = __SMLAD(inA2, inB1, sum3);
+          sum4 = __SMLAD(inA2, inB2, sum4);
 
 #else
 
-          inA1 = *pInA++;
-          inB1 = *pInB++;
-          inA2 = *pInA++;
+          inA1 = *pInA;
+          inB1 = *pInB;
           sum += inA1 * inB1;
-          inB2 = *pInB++;
 
-          inA1 = *pInA++;
-          inB1 = *pInB++;
+          inA2 = pInA[1];
+          inB2 = pInB[1];
           sum += inA2 * inB2;
-          inA2 = *pInA++;
-          inB2 = *pInB++;
 
+          inA1 = pInA[2];
+          inB1 = pInB[2];
           sum += inA1 * inB1;
+
+          inA2 = pInA[3];
+          inB2 = pInB[3];
           sum += inA2 * inB2;
 
+          pInA += 4;
+          pInB += 4;
+
 #endif	/*	#ifndef UNALIGNED_SUPPORT_DISABLE	*/
 
           /* Decrement the loop counter */
@@ -330,6 +362,18 @@ arm_status arm_mat_mult_fast_q15(
         }
 
         /* process odd column samples */
+#ifndef UNALIGNED_SUPPORT_DISABLE
+        if (numColsA & 1u) {
+          inA1 = *pInA++;
+          inB1 = *pInB++;
+          inA2 = *pInA2++;
+          inB2 = *pInB2++;
+          sum  += inA1 * inB1;
+          sum2 += inA1 * inB2;
+          sum3 += inA2 * inB1;
+          sum4 += inA2 * inB2;
+        }
+#else
         colCnt = numColsA % 0x4u;
 
         while(colCnt > 0u)
@@ -339,22 +383,146 @@ arm_status arm_mat_mult_fast_q15(
 
           colCnt--;
         }
+#endif
 
         /* Saturate and store the result in the destination buffer */
-        *px = (q15_t) (sum >> 15);
-        px++;
+        *px++  = (q15_t) (sum >> 15);
+
+#ifndef UNALIGNED_SUPPORT_DISABLE
+        *px++  = (q15_t) (sum2 >> 15);
+        *px2++ = (q15_t) (sum3 >> 15);
+        *px2++ = (q15_t) (sum4 >> 15);
+        j += numRowsB * 2;
+#endif
 
         /* Decrement the column loop counter */
         col--;
 
-      } while(col > 0u);
+      }
 
       i = i + numColsA;
 
+#ifndef UNALIGNED_SUPPORT_DISABLE
+      i = i + numColsA;
+      px = px2 + (numColsB & 1u);
+      px2 = px + numColsB;
+#endif
+
       /* Decrement the row loop counter */
       row--;
 
-    } while(row > 0u);
+    }
+
+    /* Compute any remaining odd row/column below */
+
+#ifndef UNALIGNED_SUPPORT_DISABLE
+
+    /* Compute remaining output column */
+    if (numColsB & 1u) {
+
+      /* Avoid redundant computation of last element */
+      row = numRowsA & (~0x1);
+
+      /* Point to remaining unfilled column in output matrix */
+      px = pDst->pData+numColsB-1;
+      pInA = pSrcA->pData;
+
+      /* row loop */
+      while (row > 0)
+      {
+
+        /* point to last column in matrix B */
+        pInB  = pSrcBT + numRowsB*(numColsB-1);
+
+        /* Set the variable sum, that acts as accumulator, to zero */
+        sum  = 0;
+
+        /* Compute 4 columns at once */
+        colCnt = numColsA >> 2;
+
+        /* matrix multiplication */
+        while(colCnt > 0u)
+        {
+          inA1 = *__SIMD32(pInA)++;
+          inA2 = *__SIMD32(pInA)++;
+          inB1 = *__SIMD32(pInB)++;
+          inB2 = *__SIMD32(pInB)++;
+
+          sum  = __SMLAD(inA1, inB1, sum);
+          sum  = __SMLAD(inA2, inB2, sum);
+
+          /* Decrement the loop counter */
+          colCnt--;
+        }
+
+        colCnt = numColsA & 3u;
+        while(colCnt > 0u) {
+          sum += (q31_t) (*pInA++) * (*pInB++);
+          colCnt--;
+        }
+
+        /* Store the result in the destination buffer */
+        *px  = (q15_t) (sum  >> 15);
+        px += numColsB;
+
+        /* Decrement the row loop counter */
+        row--;
+      } 
+    }
+
+    /* Compute remaining output row */
+    if (numRowsA & 1u) {
+
+      /* point to last row in output matrix */
+      px = pDst->pData+(numColsB)*(numRowsA-1);
+
+      pInB  = pSrcBT;
+      col = numColsB;
+      i = 0u;
+
+      /* col loop */
+      while (col > 0)
+      {
+
+        /* point to last row in matrix A */
+        pInA = pSrcA->pData + (numRowsA-1)*numColsA;
+
+        /* Set the variable sum, that acts as accumulator, to zero */
+        sum  = 0;
+
+        /* Compute 4 columns at once */
+        colCnt = numColsA >> 2;
+
+        /* matrix multiplication */
+        while(colCnt > 0u)
+        {
+          inA1 = *__SIMD32(pInA)++;
+          inA2 = *__SIMD32(pInA)++;
+          inB1 = *__SIMD32(pInB)++;
+          inB2 = *__SIMD32(pInB)++;
+
+          sum  = __SMLAD(inA1, inB1, sum);
+          sum  = __SMLAD(inA2, inB2, sum);
+
+          /* Decrement the loop counter */
+          colCnt--;
+        }
+
+        colCnt = numColsA & 3u;
+        while(colCnt > 0u) {
+          sum += (q31_t) (*pInA++) * (*pInB++);
+          colCnt--;
+        }
+
+        /* Store the result in the destination buffer */
+        *px++  = (q15_t) (sum  >> 15);
+
+        /* Decrement the col loop counter */
+        col--;
+      }
+    }
+
+#endif	/*	#ifndef UNALIGNED_SUPPORT_DISABLE	*/
 
     /* set status as ARM_MATH_SUCCESS */
     status = ARM_MATH_SUCCESS;

+ 228 - 58
CMSIS/DSP/Source/MatrixFunctions/arm_mat_mult_fast_q31.c

@@ -85,22 +85,27 @@ arm_status arm_mat_mult_fast_q31(
   const arm_matrix_instance_q31 * pSrcB,
   arm_matrix_instance_q31 * pDst)
 {
-  q31_t *pIn1 = pSrcA->pData;                    /* input data matrix pointer A */
-  q31_t *pIn2 = pSrcB->pData;                    /* input data matrix pointer B */
   q31_t *pInA = pSrcA->pData;                    /* input data matrix pointer A */
-//  q31_t *pSrcB = pSrcB->pData;                    /* input data matrix pointer B */    
-  q31_t *pOut = pDst->pData;                     /* output data matrix pointer */
+  q31_t *pInB = pSrcB->pData;                    /* input data matrix pointer B */
   q31_t *px;                                     /* Temporary output data matrix pointer */
   q31_t sum;                                     /* Accumulator */
   uint16_t numRowsA = pSrcA->numRows;            /* number of rows of input matrix A    */
   uint16_t numColsB = pSrcB->numCols;            /* number of columns of input matrix B */
   uint16_t numColsA = pSrcA->numCols;            /* number of columns of input matrix A */
-  uint16_t col, i = 0u, j, row = numRowsA, colCnt;      /* loop counters */
+  uint32_t col, i = 0u, j, row = numRowsA, colCnt;  /* loop counters */
   arm_status status;                             /* status of matrix multiplication */
-  q31_t inA1, inA2, inA3, inA4, inB1, inB2, inB3, inB4;
+  q31_t inA1, inB1;
 
-#ifdef ARM_MATH_MATRIX_CHECK
+#ifndef ARM_MATH_CM0_FAMILY
+
+  q31_t sum2, sum3, sum4;
+  q31_t inA2, inB2;
+  q31_t *pInA2;
+  q31_t *px2;
 
+#endif
+
+#ifdef ARM_MATH_MATRIX_CHECK
 
   /* Check for matrix mismatch condition */
   if((pSrcA->numCols != pSrcB->numRows) ||
@@ -113,110 +118,275 @@ arm_status arm_mat_mult_fast_q31(
 #endif /*      #ifdef ARM_MATH_MATRIX_CHECK    */
 
   {
+
+    px = pDst->pData;
+
+#ifndef ARM_MATH_CM0_FAMILY
+    row = row >> 1;
+    px2 = px + numColsB;
+#endif
+
     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
     /* row loop */
-    do
+    while(row > 0u)
     {
-      /* Output pointer is set to starting address of the row being processed */
-      px = pOut + i;
 
       /* For every row wise process, the column loop counter is to be initiated */
       col = numColsB;
 
       /* For every row wise process, the pIn2 pointer is set    
        ** to the starting address of the pSrcB data */
-      pIn2 = pSrcB->pData;
+      pInB = pSrcB->pData;
 
       j = 0u;
 
+#ifndef ARM_MATH_CM0_FAMILY
+      col = col >> 1;
+#endif
+
       /* column loop */
-      do
+      while (col > 0u)
       {
         /* Set the variable sum, that acts as accumulator, to zero */
         sum = 0;
 
-        /* Initiate the pointer pIn1 to point to the starting address of pInA */
-        pIn1 = pInA;
-
-        /* Apply loop unrolling and compute 4 MACs simultaneously. */
+        /* Initiate data pointers */
+        pInA = pSrcA->pData + i;
+        pInB  = pSrcB->pData + j;
+
+#ifndef ARM_MATH_CM0_FAMILY
+        sum2 = 0;
+        sum3 = 0;
+        sum4 = 0;
+        pInA2 = pInA + numColsA;
+        colCnt = numColsA;
+#else
         colCnt = numColsA >> 2;
-
+#endif
 
         /* matrix multiplication */
         while(colCnt > 0u)
         {
+
+#ifndef ARM_MATH_CM0_FAMILY
+          inA1 = *pInA++;
+          inB1 = pInB[0];
+          inA2 = *pInA2++;
+          inB2 = pInB[1];
+          pInB += numColsB;
+
+          sum  = __SMMLA(inA1, inB1, sum);
+          sum2 = __SMMLA(inA1, inB2, sum2);
+          sum3 = __SMMLA(inA2, inB1, sum3);
+          sum4 = __SMMLA(inA2, inB2, sum4);
+#else
           /* c(m,n) = a(1,1)*b(1,1) + a(1,2) * b(2,1) + .... + a(m,p)*b(p,n) */
           /* Perform the multiply-accumulates */
-          inB1 = *pIn2;
-          pIn2 += numColsB;
+          inB1 = *pInB;
+          pInB += numColsB;
+          inA1 = pInA[0];
+          sum = __SMMLA(inA1, inB1, sum);
+
+          inB1 = *pInB;
+          pInB += numColsB;
+          inA1 = pInA[1];
+          sum = __SMMLA(inA1, inB1, sum);
+
+          inB1 = *pInB;
+          pInB += numColsB;
+          inA1 = pInA[2];
+          sum = __SMMLA(inA1, inB1, sum);
+
+          inB1 = *pInB;
+          pInB += numColsB;
+          inA1 = pInA[3];
+          sum = __SMMLA(inA1, inB1, sum);
+
+          pInA += 4u;
+#endif
+          
+          /* Decrement the loop counter */
+          colCnt--;
+        }
+
+#ifdef ARM_MATH_CM0_FAMILY
+        /* If the columns of pSrcA is not a multiple of 4, compute any remaining output samples here. */
+        colCnt = numColsA % 0x4u;
+        while(colCnt > 0u)
+        {
+          sum = __SMMLA(*pInA++, *pInB, sum);
+          pInB += numColsB;
+          colCnt--;
+        }
+        j++;
+#endif
 
-          inA1 = pIn1[0];
-          inA2 = pIn1[1];
+        /* Convert the result from 2.30 to 1.31 format and store in destination buffer */
+        *px++  = sum << 1;
 
-          inB2 = *pIn2;
-          pIn2 += numColsB;
+#ifndef ARM_MATH_CM0_FAMILY        
+        *px++  = sum2 << 1; 
+        *px2++ = sum3 << 1;
+        *px2++ = sum4 << 1; 
+        j += 2;
+#endif
 
-          inB3 = *pIn2;
-          pIn2 += numColsB;
+        /* Decrement the column loop counter */
+        col--;
 
-          sum = (q31_t) ((((q63_t) sum << 32) + ((q63_t) inA1 * inB1)) >> 32);
-          sum = (q31_t) ((((q63_t) sum << 32) + ((q63_t) inA2 * inB2)) >> 32);
+      }
 
-          inA3 = pIn1[2];
-          inA4 = pIn1[3];
+      i = i + numColsA;
 
-          inB4 = *pIn2;
-          pIn2 += numColsB;
+#ifndef ARM_MATH_CM0_FAMILY  
+      i = i + numColsA;
+      px = px2 + (numColsB & 1u);
+      px2 = px + numColsB;
+#endif
 
-          sum = (q31_t) ((((q63_t) sum << 32) + ((q63_t) inA3 * inB3)) >> 32);
-          sum = (q31_t) ((((q63_t) sum << 32) + ((q63_t) inA4 * inB4)) >> 32);
+      /* Decrement the row loop counter */
+      row--;
 
-          pIn1 += 4u;
+    }
 
-          /* Decrement the loop counter */
-          colCnt--;
-        }
+    /* Compute any remaining odd row/column below */
 
-        /* If the columns of pSrcA is not a multiple of 4, compute any remaining output samples here.    
-         ** No loop unrolling is used. */
-        colCnt = numColsA % 0x4u;
+#ifndef ARM_MATH_CM0_FAMILY
+
+    /* Compute remaining output column */
+    if (numColsB & 1u) {
+
+      /* Avoid redundant computation of last element */
+      row = numRowsA & (~0x1);
+
+      /* Point to remaining unfilled column in output matrix */
+      px = pDst->pData+numColsB-1;
+      pInA = pSrcA->pData;
+
+      /* row loop */
+      while (row > 0)
+      {
+
+        /* point to last column in matrix B */
+        pInB  = pSrcB->pData + numColsB-1;
+
+        /* Set the variable sum, that acts as accumulator, to zero */
+        sum  = 0;
 
+        /* Compute 4 columns at once */
+        colCnt = numColsA >> 2;
+
+        /* matrix multiplication */
         while(colCnt > 0u)
         {
-          /* c(m,n) = a(1,1)*b(1,1) + a(1,2) * b(2,1) + .... + a(m,p)*b(p,n) */
-          /* Perform the multiply-accumulates */
-          sum = (q31_t) ((((q63_t) sum << 32) +
-                          ((q63_t) * pIn1++ * (*pIn2))) >> 32);
-          pIn2 += numColsB;
+          inA1 = *pInA++;
+          inA2 = *pInA++;
+          inB1 = *pInB;
+          pInB += numColsB;
+          inB2 = *pInB;
+          pInB += numColsB;
+          sum = __SMMLA(inA1, inB1, sum);
+          sum = __SMMLA(inA2, inB2, sum);
+
+          inA1 = *pInA++;
+          inA2 = *pInA++;
+          inB1 = *pInB;
+          pInB += numColsB;
+          inB2 = *pInB;
+          pInB += numColsB;
+          sum = __SMMLA(inA1, inB1, sum);
+          sum = __SMMLA(inA2, inB2, sum);
 
           /* Decrement the loop counter */
           colCnt--;
         }
 
+        colCnt = numColsA & 3u;
+        while(colCnt > 0u) {
+          sum = __SMMLA(*pInA++, *pInB, sum);
+          pInB += numColsB;
+          colCnt--;
+        }
+
         /* Convert the result from 2.30 to 1.31 format and store in destination buffer */
-        *px++ = sum << 1;
+        *px = sum << 1;
+        px += numColsB;
 
-        /* Update the pointer pIn2 to point to the  starting address of the next column */
-        j++;
-        pIn2 = pSrcB->pData + j;
+        /* Decrement the row loop counter */
+        row--;
+      } 
+    }
 
-        /* Decrement the column loop counter */
-        col--;
+    /* Compute remaining output row */
+    if (numRowsA & 1u) {
 
-      } while(col > 0u);
+      /* point to last row in output matrix */
+      px = pDst->pData+(numColsB)*(numRowsA-1);
 
-      /* Update the pointer pInA to point to the  starting address of the next row */
-      i = i + numColsB;
-      pInA = pInA + numColsA;
+      col = numColsB;
+      i = 0u;
 
-      /* Decrement the row loop counter */
-      row--;
+      /* col loop */
+      while (col > 0)
+      {
+
+        /* point to last row in matrix A */
+        pInA = pSrcA->pData + (numRowsA-1)*numColsA;
+        pInB  = pSrcB->pData + i;
 
-    } while(row > 0u);
+        /* Set the variable sum, that acts as accumulator, to zero */
+        sum  = 0;
+
+        /* Compute 4 columns at once */
+        colCnt = numColsA >> 2;
+
+        /* matrix multiplication */
+        while(colCnt > 0u)
+        {
+          inA1 = *pInA++;
+          inA2 = *pInA++;
+          inB1 = *pInB;
+          pInB += numColsB;
+          inB2 = *pInB;
+          pInB += numColsB;
+          sum = __SMMLA(inA1, inB1, sum);
+          sum = __SMMLA(inA2, inB2, sum);
+
+          inA1 = *pInA++;
+          inA2 = *pInA++;
+          inB1 = *pInB;
+          pInB += numColsB;
+          inB2 = *pInB;
+          pInB += numColsB;
+          sum = __SMMLA(inA1, inB1, sum);
+          sum = __SMMLA(inA2, inB2, sum);
+
+          /* Decrement the loop counter */
+          colCnt--;
+        }
+
+        colCnt = numColsA & 3u;
+        while(colCnt > 0u) {
+          sum = __SMMLA(*pInA++, *pInB, sum);
+          pInB += numColsB;
+          colCnt--;
+        }
+
+        /* Saturate and store the result in the destination buffer */
+        *px++ = sum << 1;
+        i++;
+
+        /* Decrement the col loop counter */
+        col--;
+      }
+    }
+
+#endif	/*	#ifndef ARM_MATH_CM0_FAMILY	*/
 
     /* set status as ARM_MATH_SUCCESS */
     status = ARM_MATH_SUCCESS;
   }
+
   /* Return to application */
   return (status);
 }