Explorar el Código

CMSIS-DSP: Corrected issue in scalar arm_rfft_fast_f16

Christophe Favergeon hace 5 años
padre
commit
3752e622b8

+ 6 - 4
Include/dsp/transform_functions_f16.h

@@ -26,16 +26,18 @@
 #ifndef _TRANSFORM_FUNCTIONS_F16_H_
 #define _TRANSFORM_FUNCTIONS_F16_H_
 
+#include "arm_math_types_f16.h"
+#include "arm_math_memory.h"
+
+#include "dsp/none.h"
+#include "dsp/utils.h"
+
 #ifdef   __cplusplus
 extern "C"
 {
 #endif
 
-#include "arm_math_types_f16.h"
-#include "arm_math_memory.h"
 
-#include "dsp/none.h"
-#include "dsp/utils.h"
 
 #if defined(ARM_FLOAT16_SUPPORTED)
 

+ 6 - 4
Source/TransformFunctions/arm_rfft_fast_f16.c

@@ -316,7 +316,7 @@ void stage_rfft_f16(
         float16_t * p,
         float16_t * pOut)
 {
-        uint32_t  k;                                /* Loop Counter */
+        int32_t  k;                                /* Loop Counter */
         float16_t twR, twI;                         /* RFFT Twiddle coefficients */
   const float16_t * pCoeff = S->pTwiddleRFFT;       /* Points to RFFT Twiddle factors */
         float16_t *pA = p;                          /* increasing pointer */
@@ -396,7 +396,7 @@ void stage_rfft_f16(
       pA += 2;
       pB -= 2;
       k--;
-   } while (k > 0U);
+   } while (k > 0);
 }
 
 /* Prepares data for inverse cfft */
@@ -405,7 +405,7 @@ void merge_rfft_f16(
         float16_t * p,
         float16_t * pOut)
 {
-        uint32_t  k;                                /* Loop Counter */
+        int32_t  k;                                /* Loop Counter */
         float16_t twR, twI;                         /* RFFT Twiddle coefficients */
   const float16_t *pCoeff = S->pTwiddleRFFT;        /* Points to RFFT Twiddle factors */
         float16_t *pA = p;                          /* increasing pointer */
@@ -426,7 +426,7 @@ void merge_rfft_f16(
    pB  =  p + 2*k ;
    pA +=  2	   ;
 
-   while (k > 0U)
+   while (k > 0)
    {
       /* G is half of the frequency complex spectrum */
       //for k = 2:N
@@ -583,6 +583,7 @@ void arm_rfft_fast_f16(
 {
    const arm_cfft_instance_f16 * Sint = &(S->Sint);
 
+
    /* Calculation of Real FFT */
    if (ifftFlag)
    {
@@ -593,6 +594,7 @@ void arm_rfft_fast_f16(
    }
    else
    {
+
       /* Calculation of RFFT of input */
       arm_cfft_f16( Sint, p, ifftFlag, 1);
 

+ 2 - 0
Testing/Include/Benchmarks/TransformF16.h

@@ -14,6 +14,7 @@ class TransformF16:public Client::Suite
             Client::Pattern<float16_t> samples;
 
             Client::LocalPattern<float16_t> output;
+            Client::LocalPattern<float16_t> tmp;
             Client::LocalPattern<float16_t> state;
             
             int nbSamples;
@@ -23,6 +24,7 @@ class TransformF16:public Client::Suite
             float16_t *pSrc;
             float16_t *pDst;
             float16_t *pState;
+            float16_t *pTmp;
 
             arm_cfft_instance_f16 cfftInstance;
             arm_rfft_fast_instance_f16 rfftFastInstance;

+ 2 - 0
Testing/Include/Benchmarks/TransformF32.h

@@ -14,6 +14,7 @@ class TransformF32:public Client::Suite
             Client::Pattern<float32_t> samples;
 
             Client::LocalPattern<float32_t> output;
+            Client::LocalPattern<float32_t> tmp;
             Client::LocalPattern<float32_t> state;
             
             int nbSamples;
@@ -23,6 +24,7 @@ class TransformF32:public Client::Suite
             float32_t *pSrc;
             float32_t *pDst;
             float32_t *pState;
+            float32_t *pTmp;
 
             arm_cfft_instance_f32 cfftInstance;
             arm_rfft_fast_instance_f32 rfftFastInstance;

+ 8 - 3
Testing/Source/Benchmarks/TransformF16.cpp

@@ -8,7 +8,7 @@
 
     void TransformF16::test_rfft_f16()
     { 
-       arm_rfft_fast_f16(&this->rfftFastInstance, this->pSrc, this->pDst, this->ifft);
+       arm_rfft_fast_f16(&this->rfftFastInstance, this->pTmp, this->pDst, this->ifft);
     } 
 
     void TransformF16::test_cfft_radix4_f16()
@@ -45,11 +45,16 @@
           break;
 
           case TEST_RFFT_F16_2:
-            samples.reload(TransformF16::INPUTR_F16_ID,mgr,this->nbSamples);
-            output.create(this->nbSamples,TransformF16::OUT_F16_ID,mgr);
+            // Factor 2 for irfft
+            samples.reload(TransformF16::INPUTR_F16_ID,mgr,2*this->nbSamples);
+            output.create(2*this->nbSamples,TransformF16::OUT_F16_ID,mgr);
+            tmp.create(2*this->nbSamples,TransformF16::TMP_F16_ID,mgr);
 
             this->pSrc=samples.ptr();
             this->pDst=output.ptr();
+            this->pTmp=tmp.ptr();
+
+            memcpy(this->pTmp,this->pSrc,sizeof(float16_t)*this->nbSamples); 
 
             arm_rfft_fast_init_f16(&this->rfftFastInstance, this->nbSamples);
           break;

+ 7 - 3
Testing/Source/Benchmarks/TransformF32.cpp

@@ -8,7 +8,7 @@
 
     void TransformF32::test_rfft_f32()
     { 
-       arm_rfft_fast_f32(&this->rfftFastInstance, this->pSrc, this->pDst, this->ifft);
+       arm_rfft_fast_f32(&this->rfftFastInstance, this->pTmp, this->pDst, this->ifft);
     } 
 
     void TransformF32::test_dct4_f32()
@@ -54,11 +54,16 @@
           break;
 
           case TEST_RFFT_F32_2:
-            samples.reload(TransformF32::INPUTR_F32_ID,mgr,this->nbSamples);
+            // Factor 2 for rifft
+            samples.reload(TransformF32::INPUTR_F32_ID,mgr,2*this->nbSamples);
             output.create(this->nbSamples,TransformF32::OUT_F32_ID,mgr);
+            tmp.create(this->nbSamples,TransformF32::TMP_F32_ID,mgr);
 
             this->pSrc=samples.ptr();
             this->pDst=output.ptr();
+            this->pTmp=tmp.ptr();
+
+            memcpy(this->pTmp,this->pSrc,sizeof(float32_t)*this->nbSamples); 
 
             arm_rfft_fast_init_f32(&this->rfftFastInstance, this->nbSamples);
           break;
@@ -67,7 +72,6 @@
             samples.reload(TransformF32::INPUTR_F32_ID,mgr,this->nbSamples);
             output.create(this->nbSamples,TransformF32::OUT_F32_ID,mgr);
             state.create(2*this->nbSamples,TransformF32::STATE_F32_ID,mgr);
-            
 
             this->pSrc=samples.ptr();
             this->pDst=output.ptr();

+ 2 - 0
Testing/bench.txt

@@ -1536,6 +1536,8 @@ group Root {
 
                 Pattern INPUTR_F32_ID : RealInputSamples19_f32.txt 
                 Pattern INPUTC_F32_ID : ComplexInputSamples_Noisy_512_6_f32.txt 
+                
+                Output  TMP_F32_ID : Temp
                 Output  OUT_F32_ID : Output
                 Output  STATE_F32_ID : Output
 

+ 4 - 1
Testing/bench_f16.txt

@@ -555,7 +555,9 @@ group Root {
 
                 Pattern INPUTR_F16_ID : RealInputSamples19_f16.txt 
                 Pattern INPUTC_F16_ID : ComplexInputSamples_Noisy_512_6_f16.txt 
-                Output  OUT_F16_ID : Output
+                
+                Output  OUT_F16_ID : Temp
+                Output  TMP_F16_ID : Output
                 Output  STATE_F16_ID : Output
 
                 
@@ -577,6 +579,7 @@ group Root {
                   REV = [1]
                 }
 
+   
 
                 Functions {
                    Complex FFT:test_cfft_f16 -> CFFT_PARAM_ID