arm_barycenter_f16.c 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. /* ----------------------------------------------------------------------
  2. * Project: CMSIS DSP Library
  3. * Title: arm_barycenter_f16.c
  4. * Description: Barycenter
  5. *
  6. * $Date: 23 April 2021
  7. * $Revision: V1.9.0
  8. *
  9. * Target Processor: Cortex-M and Cortex-A cores
  10. * -------------------------------------------------------------------- */
  11. /*
  12. * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
  13. *
  14. * SPDX-License-Identifier: Apache-2.0
  15. *
  16. * Licensed under the Apache License, Version 2.0 (the License); you may
  17. * not use this file except in compliance with the License.
  18. * You may obtain a copy of the License at
  19. *
  20. * www.apache.org/licenses/LICENSE-2.0
  21. *
  22. * Unless required by applicable law or agreed to in writing, software
  23. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  24. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  25. * See the License for the specific language governing permissions and
  26. * limitations under the License.
  27. */
  28. #include "dsp/support_functions_f16.h"
  29. #if defined(ARM_FLOAT16_SUPPORTED)
  30. #include <limits.h>
  31. #include <math.h>
  32. /**
  33. @ingroup groupSupport
  34. */
  35. /**
  36. @defgroup barycenter Barycenter
  37. Barycenter of weighted vectors
  38. */
  39. /**
  40. @addtogroup barycenter
  41. @{
  42. */
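  /**
   * @brief         Barycenter (weighted mean) of a set of vectors
   *
   * Computes, for d = 0 .. vecDim - 1:
   *   out[d] = (sum_i weights[i] * in[i * vecDim + d]) / (sum_i weights[i])
   *
   * @param[in]     in         nbVectors vectors of vecDim elements each, stored
   *                           one after another (nbVectors * vecDim values)
   * @param[in]     weights    nbVectors weights, one per vector
   * @param[out]    out        barycenter (vecDim elements)
   * @param[in]     nbVectors  number of vectors
   * @param[in]     vecDim     dimension of the space (length of each vector)
   * @return        none
   *
   * @note The result is normalized by the sum of the weights, so the weights
   *       must not sum to zero. out[] is cleared first and then used as the
   *       accumulator, so it must not overlap in[] or weights[].
   */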
  55. #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
  56. void arm_barycenter_f16(const float16_t *in,
  57. const float16_t *weights,
  58. float16_t *out,
  59. uint32_t nbVectors,
  60. uint32_t vecDim)
  61. {
  62. const float16_t *pIn, *pW;
  63. const float16_t *pIn1, *pIn2, *pIn3, *pIn4;
  64. float16_t *pOut;
  65. uint32_t blkCntVector, blkCntSample;
  66. float16_t accum, w;
  67. blkCntVector = nbVectors;
  68. blkCntSample = vecDim;
  69. accum = 0.0f;
  70. pW = weights;
  71. pIn = in;
  72. arm_fill_f16(0.0f, out, vecDim);
  73. /* Sum */
  74. pIn1 = pIn;
  75. pIn2 = pIn1 + vecDim;
  76. pIn3 = pIn2 + vecDim;
  77. pIn4 = pIn3 + vecDim;
  78. blkCntVector = nbVectors >> 2;
  79. while (blkCntVector > 0)
  80. {
  81. f16x8_t outV, inV1, inV2, inV3, inV4;
  82. float16_t w1, w2, w3, w4;
  83. pOut = out;
  84. w1 = *pW++;
  85. w2 = *pW++;
  86. w3 = *pW++;
  87. w4 = *pW++;
  88. accum += w1 + w2 + w3 + w4;
  89. blkCntSample = vecDim >> 3;
  90. while (blkCntSample > 0) {
  91. outV = vld1q((const float16_t *) pOut);
  92. inV1 = vld1q(pIn1);
  93. inV2 = vld1q(pIn2);
  94. inV3 = vld1q(pIn3);
  95. inV4 = vld1q(pIn4);
  96. outV = vfmaq(outV, inV1, w1);
  97. outV = vfmaq(outV, inV2, w2);
  98. outV = vfmaq(outV, inV3, w3);
  99. outV = vfmaq(outV, inV4, w4);
  100. vst1q(pOut, outV);
  101. pOut += 8;
  102. pIn1 += 8;
  103. pIn2 += 8;
  104. pIn3 += 8;
  105. pIn4 += 8;
  106. blkCntSample--;
  107. }
  108. blkCntSample = vecDim & 7;
  109. while (blkCntSample > 0) {
  110. *pOut = *pOut + *pIn1++ * w1;
  111. *pOut = *pOut + *pIn2++ * w2;
  112. *pOut = *pOut + *pIn3++ * w3;
  113. *pOut = *pOut + *pIn4++ * w4;
  114. pOut++;
  115. blkCntSample--;
  116. }
  117. pIn1 += 3 * vecDim;
  118. pIn2 += 3 * vecDim;
  119. pIn3 += 3 * vecDim;
  120. pIn4 += 3 * vecDim;
  121. blkCntVector--;
  122. }
  123. pIn = pIn1;
  124. blkCntVector = nbVectors & 3;
  125. while (blkCntVector > 0)
  126. {
  127. f16x8_t inV, outV;
  128. pOut = out;
  129. w = *pW++;
  130. accum += w;
  131. blkCntSample = vecDim >> 3;
  132. while (blkCntSample > 0)
  133. {
  134. outV = vld1q_f16(pOut);
  135. inV = vld1q_f16(pIn);
  136. outV = vfmaq(outV, inV, w);
  137. vst1q_f16(pOut, outV);
  138. pOut += 8;
  139. pIn += 8;
  140. blkCntSample--;
  141. }
  142. blkCntSample = vecDim & 7;
  143. while (blkCntSample > 0)
  144. {
  145. *pOut = *pOut + *pIn++ * w;
  146. pOut++;
  147. blkCntSample--;
  148. }
  149. blkCntVector--;
  150. }
  151. /* Normalize */
  152. pOut = out;
  153. accum = 1.0f / accum;
  154. blkCntSample = vecDim >> 3;
  155. while (blkCntSample > 0)
  156. {
  157. f16x8_t tmp;
  158. tmp = vld1q((const float16_t *) pOut);
  159. tmp = vmulq(tmp, accum);
  160. vst1q(pOut, tmp);
  161. pOut += 8;
  162. blkCntSample--;
  163. }
  164. blkCntSample = vecDim & 7;
  165. while (blkCntSample > 0)
  166. {
  167. *pOut = *pOut * accum;
  168. pOut++;
  169. blkCntSample--;
  170. }
  171. }
  172. #else
  173. void arm_barycenter_f16(const float16_t *in, const float16_t *weights, float16_t *out, uint32_t nbVectors,uint32_t vecDim)
  174. {
  175. const float16_t *pIn,*pW;
  176. float16_t *pOut;
  177. uint32_t blkCntVector,blkCntSample;
  178. float16_t accum, w;
  179. blkCntVector = nbVectors;
  180. blkCntSample = vecDim;
  181. accum = 0.0f;
  182. pW = weights;
  183. pIn = in;
  184. /* Set counters to 0 */
  185. blkCntSample = vecDim;
  186. pOut = out;
  187. while(blkCntSample > 0)
  188. {
  189. *pOut = 0.0f;
  190. pOut++;
  191. blkCntSample--;
  192. }
  193. /* Sum */
  194. while(blkCntVector > 0)
  195. {
  196. pOut = out;
  197. w = *pW++;
  198. accum += w;
  199. blkCntSample = vecDim;
  200. while(blkCntSample > 0)
  201. {
  202. *pOut = *pOut + *pIn++ * w;
  203. pOut++;
  204. blkCntSample--;
  205. }
  206. blkCntVector--;
  207. }
  208. /* Normalize */
  209. blkCntSample = vecDim;
  210. pOut = out;
  211. while(blkCntSample > 0)
  212. {
  213. *pOut = *pOut / accum;
  214. pOut++;
  215. blkCntSample--;
  216. }
  217. }
  #endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
  219. /**
  220. * @} end of barycenter group
  221. */
  222. #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */