svm_functions_f16.h 9.8 KB


  1. /******************************************************************************
  2. * @file svm_functions_f16.h
  3. * @brief Public header file for CMSIS DSP Library
  4. * @version V1.9.0
  5. * @date 23 April 2021
  6. * Target Processor: Cortex-M and Cortex-A cores
  7. ******************************************************************************/
  8. /*
  9. * Copyright (c) 2010-2020 Arm Limited or its affiliates. All rights reserved.
  10. *
  11. * SPDX-License-Identifier: Apache-2.0
  12. *
  13. * Licensed under the Apache License, Version 2.0 (the License); you may
  14. * not use this file except in compliance with the License.
  15. * You may obtain a copy of the License at
  16. *
  17. * www.apache.org/licenses/LICENSE-2.0
  18. *
  19. * Unless required by applicable law or agreed to in writing, software
  20. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  21. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. * See the License for the specific language governing permissions and
  23. * limitations under the License.
  24. */
  25. #ifndef _SVM_FUNCTIONS_F16_H_
  26. #define _SVM_FUNCTIONS_F16_H_
  27. #include "arm_math_types_f16.h"
  28. #include "arm_math_memory.h"
  29. #include "dsp/none.h"
  30. #include "dsp/utils.h"
  31. #include "dsp/svm_defines.h"
  32. #ifdef __cplusplus
  33. extern "C"
  34. {
  35. #endif
  36. #if defined(ARM_FLOAT16_SUPPORTED)
  /* Step function: evaluates to 0 when x <= 0 and to 1 otherwise (presumably
   * used to select one of the two entries of an SVM instance's `classes` array).
   * The whole expansion is parenthesized so the macro composes correctly inside
   * larger expressions such as `1 + STEP(v)`; without the outer parentheses that
   * would parse as `(1 + (v)) <= 0 ? 0 : 1`. x is evaluated exactly once. */
  #define STEP(x) ((x) <= 0 ? 0 : 1)
  38. /**
  39. * @defgroup groupSVM SVM Functions
  * This set of functions implements SVM classification on 2 classes.
  * The training must be done with scikit-learn. The parameters can easily be
  * generated from the trained scikit-learn object; some examples are given in
  * DSP/Testing/PatternGeneration/SVM.py.
  *
  * If more than 2 classes are needed, the functions in this folder
  * must be used as building blocks to implement multi-class classification.
  *
  * No multi-class classification is provided in this SVM folder.
  49. *
  50. */
  51. /**
  52. * @brief Integer exponentiation
  53. * @param[in] x value
  54. * @param[in] nb integer exponent >= 1
  55. * @return x^nb
  56. *
  57. */
  58. __STATIC_INLINE float16_t arm_exponent_f16(float16_t x, int32_t nb)
  59. {
  60. float16_t r = x;
  61. nb --;
  62. while(nb > 0)
  63. {
  64. r = r * x;
  65. nb--;
  66. }
  67. return(r);
  68. }
  69. /**
  70. * @brief Instance structure for linear SVM prediction function.
  71. */
  72. typedef struct
  73. {
  74. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  75. uint32_t vectorDimension; /**< Dimension of vector space */
  76. float16_t intercept; /**< Intercept */
  77. const float16_t *dualCoefficients; /**< Dual coefficients */
  78. const float16_t *supportVectors; /**< Support vectors */
  79. const int32_t *classes; /**< The two SVM classes */
  80. } arm_svm_linear_instance_f16;
  81. /**
  82. * @brief Instance structure for polynomial SVM prediction function.
  83. */
  84. typedef struct
  85. {
  86. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  87. uint32_t vectorDimension; /**< Dimension of vector space */
  88. float16_t intercept; /**< Intercept */
  89. const float16_t *dualCoefficients; /**< Dual coefficients */
  90. const float16_t *supportVectors; /**< Support vectors */
  91. const int32_t *classes; /**< The two SVM classes */
  92. int32_t degree; /**< Polynomial degree */
  93. float16_t coef0; /**< Polynomial constant */
  94. float16_t gamma; /**< Gamma factor */
  95. } arm_svm_polynomial_instance_f16;
  96. /**
  97. * @brief Instance structure for rbf SVM prediction function.
  98. */
  99. typedef struct
  100. {
  101. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  102. uint32_t vectorDimension; /**< Dimension of vector space */
  103. float16_t intercept; /**< Intercept */
  104. const float16_t *dualCoefficients; /**< Dual coefficients */
  105. const float16_t *supportVectors; /**< Support vectors */
  106. const int32_t *classes; /**< The two SVM classes */
  107. float16_t gamma; /**< Gamma factor */
  108. } arm_svm_rbf_instance_f16;
  109. /**
  110. * @brief Instance structure for sigmoid SVM prediction function.
  111. */
  112. typedef struct
  113. {
  114. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  115. uint32_t vectorDimension; /**< Dimension of vector space */
  116. float16_t intercept; /**< Intercept */
  117. const float16_t *dualCoefficients; /**< Dual coefficients */
  118. const float16_t *supportVectors; /**< Support vectors */
  119. const int32_t *classes; /**< The two SVM classes */
  120. float16_t coef0; /**< Independent constant */
  121. float16_t gamma; /**< Gamma factor */
  122. } arm_svm_sigmoid_instance_f16;
  123. /**
  124. * @brief SVM linear instance init function
  125. * @param[in] S Parameters for SVM functions
  126. * @param[in] nbOfSupportVectors Number of support vectors
  127. * @param[in] vectorDimension Dimension of vector space
  128. * @param[in] intercept Intercept
  129. * @param[in] dualCoefficients Array of dual coefficients
  130. * @param[in] supportVectors Array of support vectors
  131. * @param[in] classes Array of 2 classes ID
  132. * @return none.
  133. *
  134. */
  135. void arm_svm_linear_init_f16(arm_svm_linear_instance_f16 *S,
  136. uint32_t nbOfSupportVectors,
  137. uint32_t vectorDimension,
  138. float16_t intercept,
  139. const float16_t *dualCoefficients,
  140. const float16_t *supportVectors,
  141. const int32_t *classes);
  142. /**
  143. * @brief SVM linear prediction
  144. * @param[in] S Pointer to an instance of the linear SVM structure.
  145. * @param[in] in Pointer to input vector
  146. * @param[out] pResult Decision value
  147. * @return none.
  148. *
  149. */
  150. void arm_svm_linear_predict_f16(const arm_svm_linear_instance_f16 *S,
  151. const float16_t * in,
  152. int32_t * pResult);
  153. /**
  154. * @brief SVM polynomial instance init function
  155. * @param[in] S points to an instance of the polynomial SVM structure.
  156. * @param[in] nbOfSupportVectors Number of support vectors
  157. * @param[in] vectorDimension Dimension of vector space
  158. * @param[in] intercept Intercept
  159. * @param[in] dualCoefficients Array of dual coefficients
  160. * @param[in] supportVectors Array of support vectors
  161. * @param[in] classes Array of 2 classes ID
  162. * @param[in] degree Polynomial degree
  163. * @param[in] coef0 coeff0 (scikit-learn terminology)
  164. * @param[in] gamma gamma (scikit-learn terminology)
  165. * @return none.
  166. *
  167. */
  168. void arm_svm_polynomial_init_f16(arm_svm_polynomial_instance_f16 *S,
  169. uint32_t nbOfSupportVectors,
  170. uint32_t vectorDimension,
  171. float16_t intercept,
  172. const float16_t *dualCoefficients,
  173. const float16_t *supportVectors,
  174. const int32_t *classes,
  175. int32_t degree,
  176. float16_t coef0,
  177. float16_t gamma
  178. );
  179. /**
  180. * @brief SVM polynomial prediction
  181. * @param[in] S Pointer to an instance of the polynomial SVM structure.
  182. * @param[in] in Pointer to input vector
  183. * @param[out] pResult Decision value
  184. * @return none.
  185. *
  186. */
  187. void arm_svm_polynomial_predict_f16(const arm_svm_polynomial_instance_f16 *S,
  188. const float16_t * in,
  189. int32_t * pResult);
  190. /**
  191. * @brief SVM radial basis function instance init function
  192. * @param[in] S points to an instance of the polynomial SVM structure.
  193. * @param[in] nbOfSupportVectors Number of support vectors
  194. * @param[in] vectorDimension Dimension of vector space
  195. * @param[in] intercept Intercept
  196. * @param[in] dualCoefficients Array of dual coefficients
  197. * @param[in] supportVectors Array of support vectors
  198. * @param[in] classes Array of 2 classes ID
  199. * @param[in] gamma gamma (scikit-learn terminology)
  200. * @return none.
  201. *
  202. */
  203. void arm_svm_rbf_init_f16(arm_svm_rbf_instance_f16 *S,
  204. uint32_t nbOfSupportVectors,
  205. uint32_t vectorDimension,
  206. float16_t intercept,
  207. const float16_t *dualCoefficients,
  208. const float16_t *supportVectors,
  209. const int32_t *classes,
  210. float16_t gamma
  211. );
  212. /**
  213. * @brief SVM rbf prediction
  214. * @param[in] S Pointer to an instance of the rbf SVM structure.
  215. * @param[in] in Pointer to input vector
  216. * @param[out] pResult decision value
  217. * @return none.
  218. *
  219. */
  220. void arm_svm_rbf_predict_f16(const arm_svm_rbf_instance_f16 *S,
  221. const float16_t * in,
  222. int32_t * pResult);
  223. /**
  224. * @brief SVM sigmoid instance init function
  225. * @param[in] S points to an instance of the rbf SVM structure.
  226. * @param[in] nbOfSupportVectors Number of support vectors
  227. * @param[in] vectorDimension Dimension of vector space
  228. * @param[in] intercept Intercept
  229. * @param[in] dualCoefficients Array of dual coefficients
  230. * @param[in] supportVectors Array of support vectors
  231. * @param[in] classes Array of 2 classes ID
  232. * @param[in] coef0 coeff0 (scikit-learn terminology)
  233. * @param[in] gamma gamma (scikit-learn terminology)
  234. * @return none.
  235. *
  236. */
  237. void arm_svm_sigmoid_init_f16(arm_svm_sigmoid_instance_f16 *S,
  238. uint32_t nbOfSupportVectors,
  239. uint32_t vectorDimension,
  240. float16_t intercept,
  241. const float16_t *dualCoefficients,
  242. const float16_t *supportVectors,
  243. const int32_t *classes,
  244. float16_t coef0,
  245. float16_t gamma
  246. );
  247. /**
  248. * @brief SVM sigmoid prediction
  249. * @param[in] S Pointer to an instance of the rbf SVM structure.
  250. * @param[in] in Pointer to input vector
  251. * @param[out] pResult Decision value
  252. * @return none.
  253. *
  254. */
  255. void arm_svm_sigmoid_predict_f16(const arm_svm_sigmoid_instance_f16 *S,
  256. const float16_t * in,
  257. int32_t * pResult);
  258. #endif /*defined(ARM_FLOAT16_SUPPORTED)*/
  259. #ifdef __cplusplus
  260. }
  261. #endif
  262. #endif /* ifndef _SVM_FUNCTIONS_F16_H_ */