arm_barycenter_f16.c 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. /* ----------------------------------------------------------------------
  2. * Project: CMSIS DSP Library
  3. * Title: arm_barycenter_f16.c
  4. * Description: Barycenter
  5. *
  6. * $Date: 23 April 2021
  7. * $Revision: V1.9.0
  8. *
  9. * Target Processor: Cortex-M and Cortex-A cores
  10. * -------------------------------------------------------------------- */
  11. /*
  12. * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
  13. *
  14. * SPDX-License-Identifier: Apache-2.0
  15. *
  16. * Licensed under the Apache License, Version 2.0 (the License); you may
  17. * not use this file except in compliance with the License.
  18. * You may obtain a copy of the License at
  19. *
  20. * www.apache.org/licenses/LICENSE-2.0
  21. *
  22. * Unless required by applicable law or agreed to in writing, software
  23. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  24. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  25. * See the License for the specific language governing permissions and
  26. * limitations under the License.
  27. */
  28. #include "dsp/support_functions_f16.h"
  29. #if defined(ARM_FLOAT16_SUPPORTED)
  30. #include <limits.h>
  31. #include <math.h>
  32. /**
  33. @ingroup groupSupport
  34. */
  35. /**
  36. @defgroup barycenter Barycenter
  37. Barycenter of weighted vectors
  38. */
  39. /**
  40. @addtogroup barycenter
  41. @{
  42. */
  43. /**
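  44. * @brief Barycenter of a set of weighted vectors
  45. *
  46. * Computes out = sum(weights[i] * in[i]) / sum(weights[i]) over all vectors.
  47. * @param[in] in List of vectors, stored one after another (vecDim values each)
  48. * @param[in] weights Weights of the vectors (one weight per vector)
  49. * @param[out] out Barycenter (vecDim values)
  50. * @param[in] nbVectors Number of vectors
  51. * @param[in] vecDim Dimension of space (vector dimension)
  52. *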
  53. */
  54. #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
  55. void arm_barycenter_f16(const float16_t *in,
  56. const float16_t *weights,
  57. float16_t *out,
  58. uint32_t nbVectors,
  59. uint32_t vecDim)
  60. {
  61. const float16_t *pIn, *pW;
  62. const float16_t *pIn1, *pIn2, *pIn3, *pIn4;
  63. float16_t *pOut;
  64. uint32_t blkCntVector, blkCntSample;
  65. float16_t accum, w;
  66. blkCntVector = nbVectors;
  67. blkCntSample = vecDim;
  68. accum = 0.0f;
  69. pW = weights;
  70. pIn = in;
  71. arm_fill_f16(0.0f, out, vecDim);
  72. /* Sum */
  73. pIn1 = pIn;
  74. pIn2 = pIn1 + vecDim;
  75. pIn3 = pIn2 + vecDim;
  76. pIn4 = pIn3 + vecDim;
  77. blkCntVector = nbVectors >> 2;
  78. while (blkCntVector > 0)
  79. {
  80. f16x8_t outV, inV1, inV2, inV3, inV4;
  81. float16_t w1, w2, w3, w4;
  82. pOut = out;
  83. w1 = *pW++;
  84. w2 = *pW++;
  85. w3 = *pW++;
  86. w4 = *pW++;
  87. accum += (_Float16)w1 + (_Float16)w2 + (_Float16)w3 + (_Float16)w4;
  88. blkCntSample = vecDim >> 3;
  89. while (blkCntSample > 0) {
  90. outV = vld1q((const float16_t *) pOut);
  91. inV1 = vld1q(pIn1);
  92. inV2 = vld1q(pIn2);
  93. inV3 = vld1q(pIn3);
  94. inV4 = vld1q(pIn4);
  95. outV = vfmaq(outV, inV1, w1);
  96. outV = vfmaq(outV, inV2, w2);
  97. outV = vfmaq(outV, inV3, w3);
  98. outV = vfmaq(outV, inV4, w4);
  99. vst1q(pOut, outV);
  100. pOut += 8;
  101. pIn1 += 8;
  102. pIn2 += 8;
  103. pIn3 += 8;
  104. pIn4 += 8;
  105. blkCntSample--;
  106. }
  107. blkCntSample = vecDim & 7;
  108. while (blkCntSample > 0) {
  109. *pOut = (_Float16)*pOut + (_Float16)*pIn1++ * (_Float16)w1;
  110. *pOut = (_Float16)*pOut + (_Float16)*pIn2++ * (_Float16)w2;
  111. *pOut = (_Float16)*pOut + (_Float16)*pIn3++ * (_Float16)w3;
  112. *pOut = (_Float16)*pOut + (_Float16)*pIn4++ * (_Float16)w4;
  113. pOut++;
  114. blkCntSample--;
  115. }
  116. pIn1 += 3 * vecDim;
  117. pIn2 += 3 * vecDim;
  118. pIn3 += 3 * vecDim;
  119. pIn4 += 3 * vecDim;
  120. blkCntVector--;
  121. }
  122. pIn = pIn1;
  123. blkCntVector = nbVectors & 3;
  124. while (blkCntVector > 0)
  125. {
  126. f16x8_t inV, outV;
  127. pOut = out;
  128. w = *pW++;
  129. accum += (_Float16)w;
  130. blkCntSample = vecDim >> 3;
  131. while (blkCntSample > 0)
  132. {
  133. outV = vld1q_f16(pOut);
  134. inV = vld1q_f16(pIn);
  135. outV = vfmaq(outV, inV, w);
  136. vst1q_f16(pOut, outV);
  137. pOut += 8;
  138. pIn += 8;
  139. blkCntSample--;
  140. }
  141. blkCntSample = vecDim & 7;
  142. while (blkCntSample > 0)
  143. {
  144. *pOut = (_Float16)*pOut + (_Float16)*pIn++ * (_Float16)w;
  145. pOut++;
  146. blkCntSample--;
  147. }
  148. blkCntVector--;
  149. }
  150. /* Normalize */
  151. pOut = out;
  152. accum = 1.0f16 / (_Float16)accum;
  153. blkCntSample = vecDim >> 3;
  154. while (blkCntSample > 0)
  155. {
  156. f16x8_t tmp;
  157. tmp = vld1q((const float16_t *) pOut);
  158. tmp = vmulq(tmp, accum);
  159. vst1q(pOut, tmp);
  160. pOut += 8;
  161. blkCntSample--;
  162. }
  163. blkCntSample = vecDim & 7;
  164. while (blkCntSample > 0)
  165. {
  166. *pOut = (_Float16)*pOut * (_Float16)accum;
  167. pOut++;
  168. blkCntSample--;
  169. }
  170. }
  171. #else
  172. void arm_barycenter_f16(const float16_t *in, const float16_t *weights, float16_t *out, uint32_t nbVectors,uint32_t vecDim)
  173. {
  174. const float16_t *pIn,*pW;
  175. float16_t *pOut;
  176. uint32_t blkCntVector,blkCntSample;
  177. float16_t accum, w;
  178. blkCntVector = nbVectors;
  179. blkCntSample = vecDim;
  180. accum = 0.0f16;
  181. pW = weights;
  182. pIn = in;
  183. /* Set counters to 0 */
  184. blkCntSample = vecDim;
  185. pOut = out;
  186. while(blkCntSample > 0)
  187. {
  188. *pOut = 0.0f16;
  189. pOut++;
  190. blkCntSample--;
  191. }
  192. /* Sum */
  193. while(blkCntVector > 0)
  194. {
  195. pOut = out;
  196. w = *pW++;
  197. accum += (_Float16)w;
  198. blkCntSample = vecDim;
  199. while(blkCntSample > 0)
  200. {
  201. *pOut = (_Float16)*pOut + (_Float16)*pIn++ * (_Float16)w;
  202. pOut++;
  203. blkCntSample--;
  204. }
  205. blkCntVector--;
  206. }
  207. /* Normalize */
  208. blkCntSample = vecDim;
  209. pOut = out;
  210. while(blkCntSample > 0)
  211. {
  212. *pOut = (_Float16)*pOut / (_Float16)accum;
  213. pOut++;
  214. blkCntSample--;
  215. }
  216. }
  217. #endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
  218. /**
  219. * @} end of barycenter group
  220. */
  221. #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */