svm_functions.h 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. /******************************************************************************
  2. * @file svm_functions.h
  3. * @brief Public header file for CMSIS DSP Library
  4. * @version V1.9.0
  5. * @date 20. July 2020
  6. ******************************************************************************/
  7. /*
  8. * Copyright (c) 2010-2020 Arm Limited or its affiliates. All rights reserved.
  9. *
  10. * SPDX-License-Identifier: Apache-2.0
  11. *
  12. * Licensed under the Apache License, Version 2.0 (the License); you may
  13. * not use this file except in compliance with the License.
  14. * You may obtain a copy of the License at
  15. *
  16. * www.apache.org/licenses/LICENSE-2.0
  17. *
  18. * Unless required by applicable law or agreed to in writing, software
  19. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  20. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. * See the License for the specific language governing permissions and
  22. * limitations under the License.
  23. */
  24. #ifndef _SVM_FUNCTIONS_H_
  25. #define _SVM_FUNCTIONS_H_
  26. #include "arm_math_types.h"
  27. #include "arm_math_memory.h"
  28. #include "dsp/none.h"
  29. #include "dsp/utils.h"
  30. #include "dsp/svm_defines.h"
  31. #ifdef __cplusplus
  32. extern "C"
  33. {
  34. #endif
  /**
   * @brief Step function mapping a real value to 0 or 1.
   * @param[in] x value to test (evaluated exactly once)
   * @return 0 when x <= 0, otherwise 1
   *
   * The whole expansion is parenthesized so that STEP(x) composes safely
   * inside larger expressions such as STEP(a) + 1 or 2 * STEP(a); without
   * the outer parentheses the ?: operator would swallow the surrounding
   * arithmetic.
   */
  #define STEP(x) (((x) <= 0) ? 0 : 1)
  /**
   * @defgroup groupSVM SVM Functions
   * This set of functions implements SVM classification on 2 classes.
   * The training must be done with scikit-learn. The parameters can easily be
   * generated from the scikit-learn object. Some examples are given in
   * DSP/Testing/PatternGeneration/SVM.py
   *
   * If more than 2 classes are needed, the functions in this folder
   * must be used as building blocks to implement multi-class classification.
   *
   * No multi-class classification is provided in this SVM folder.
   *
   */
  49. /**
  50. * @brief Integer exponentiation
  51. * @param[in] x value
  52. * @param[in] nb integer exponent >= 1
  53. * @return x^nb
  54. *
  55. */
  56. __STATIC_INLINE float32_t arm_exponent_f32(float32_t x, int32_t nb)
  57. {
  58. float32_t r = x;
  59. nb --;
  60. while(nb > 0)
  61. {
  62. r = r * x;
  63. nb--;
  64. }
  65. return(r);
  66. }
  67. /**
  68. * @brief Instance structure for linear SVM prediction function.
  69. */
  70. typedef struct
  71. {
  72. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  73. uint32_t vectorDimension; /**< Dimension of vector space */
  74. float32_t intercept; /**< Intercept */
  75. const float32_t *dualCoefficients; /**< Dual coefficients */
  76. const float32_t *supportVectors; /**< Support vectors */
  77. const int32_t *classes; /**< The two SVM classes */
  78. } arm_svm_linear_instance_f32;
  79. /**
  80. * @brief Instance structure for polynomial SVM prediction function.
  81. */
  82. typedef struct
  83. {
  84. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  85. uint32_t vectorDimension; /**< Dimension of vector space */
  86. float32_t intercept; /**< Intercept */
  87. const float32_t *dualCoefficients; /**< Dual coefficients */
  88. const float32_t *supportVectors; /**< Support vectors */
  89. const int32_t *classes; /**< The two SVM classes */
  90. int32_t degree; /**< Polynomial degree */
  91. float32_t coef0; /**< Polynomial constant */
  92. float32_t gamma; /**< Gamma factor */
  93. } arm_svm_polynomial_instance_f32;
  94. /**
  95. * @brief Instance structure for rbf SVM prediction function.
  96. */
  97. typedef struct
  98. {
  99. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  100. uint32_t vectorDimension; /**< Dimension of vector space */
  101. float32_t intercept; /**< Intercept */
  102. const float32_t *dualCoefficients; /**< Dual coefficients */
  103. const float32_t *supportVectors; /**< Support vectors */
  104. const int32_t *classes; /**< The two SVM classes */
  105. float32_t gamma; /**< Gamma factor */
  106. } arm_svm_rbf_instance_f32;
  107. /**
  108. * @brief Instance structure for sigmoid SVM prediction function.
  109. */
  110. typedef struct
  111. {
  112. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  113. uint32_t vectorDimension; /**< Dimension of vector space */
  114. float32_t intercept; /**< Intercept */
  115. const float32_t *dualCoefficients; /**< Dual coefficients */
  116. const float32_t *supportVectors; /**< Support vectors */
  117. const int32_t *classes; /**< The two SVM classes */
  118. float32_t coef0; /**< Independant constant */
  119. float32_t gamma; /**< Gamma factor */
  120. } arm_svm_sigmoid_instance_f32;
  121. /**
  122. * @brief SVM linear instance init function
  123. * @param[in] S Parameters for SVM functions
  124. * @param[in] nbOfSupportVectors Number of support vectors
  125. * @param[in] vectorDimension Dimension of vector space
  126. * @param[in] intercept Intercept
  127. * @param[in] dualCoefficients Array of dual coefficients
  128. * @param[in] supportVectors Array of support vectors
  129. * @param[in] classes Array of 2 classes ID
  130. * @return none.
  131. *
  132. */
  133. void arm_svm_linear_init_f32(arm_svm_linear_instance_f32 *S,
  134. uint32_t nbOfSupportVectors,
  135. uint32_t vectorDimension,
  136. float32_t intercept,
  137. const float32_t *dualCoefficients,
  138. const float32_t *supportVectors,
  139. const int32_t *classes);
  140. /**
  141. * @brief SVM linear prediction
  142. * @param[in] S Pointer to an instance of the linear SVM structure.
  143. * @param[in] in Pointer to input vector
  144. * @param[out] pResult Decision value
  145. * @return none.
  146. *
  147. */
  148. void arm_svm_linear_predict_f32(const arm_svm_linear_instance_f32 *S,
  149. const float32_t * in,
  150. int32_t * pResult);
  151. /**
  152. * @brief SVM polynomial instance init function
  153. * @param[in] S points to an instance of the polynomial SVM structure.
  154. * @param[in] nbOfSupportVectors Number of support vectors
  155. * @param[in] vectorDimension Dimension of vector space
  156. * @param[in] intercept Intercept
  157. * @param[in] dualCoefficients Array of dual coefficients
  158. * @param[in] supportVectors Array of support vectors
  159. * @param[in] classes Array of 2 classes ID
  160. * @param[in] degree Polynomial degree
  161. * @param[in] coef0 coeff0 (scikit-learn terminology)
  162. * @param[in] gamma gamma (scikit-learn terminology)
  163. * @return none.
  164. *
  165. */
  166. void arm_svm_polynomial_init_f32(arm_svm_polynomial_instance_f32 *S,
  167. uint32_t nbOfSupportVectors,
  168. uint32_t vectorDimension,
  169. float32_t intercept,
  170. const float32_t *dualCoefficients,
  171. const float32_t *supportVectors,
  172. const int32_t *classes,
  173. int32_t degree,
  174. float32_t coef0,
  175. float32_t gamma
  176. );
  177. /**
  178. * @brief SVM polynomial prediction
  179. * @param[in] S Pointer to an instance of the polynomial SVM structure.
  180. * @param[in] in Pointer to input vector
  181. * @param[out] pResult Decision value
  182. * @return none.
  183. *
  184. */
  185. void arm_svm_polynomial_predict_f32(const arm_svm_polynomial_instance_f32 *S,
  186. const float32_t * in,
  187. int32_t * pResult);
  188. /**
  189. * @brief SVM radial basis function instance init function
  190. * @param[in] S points to an instance of the polynomial SVM structure.
  191. * @param[in] nbOfSupportVectors Number of support vectors
  192. * @param[in] vectorDimension Dimension of vector space
  193. * @param[in] intercept Intercept
  194. * @param[in] dualCoefficients Array of dual coefficients
  195. * @param[in] supportVectors Array of support vectors
  196. * @param[in] classes Array of 2 classes ID
  197. * @param[in] gamma gamma (scikit-learn terminology)
  198. * @return none.
  199. *
  200. */
  201. void arm_svm_rbf_init_f32(arm_svm_rbf_instance_f32 *S,
  202. uint32_t nbOfSupportVectors,
  203. uint32_t vectorDimension,
  204. float32_t intercept,
  205. const float32_t *dualCoefficients,
  206. const float32_t *supportVectors,
  207. const int32_t *classes,
  208. float32_t gamma
  209. );
  210. /**
  211. * @brief SVM rbf prediction
  212. * @param[in] S Pointer to an instance of the rbf SVM structure.
  213. * @param[in] in Pointer to input vector
  214. * @param[out] pResult decision value
  215. * @return none.
  216. *
  217. */
  218. void arm_svm_rbf_predict_f32(const arm_svm_rbf_instance_f32 *S,
  219. const float32_t * in,
  220. int32_t * pResult);
  221. /**
  222. * @brief SVM sigmoid instance init function
  223. * @param[in] S points to an instance of the rbf SVM structure.
  224. * @param[in] nbOfSupportVectors Number of support vectors
  225. * @param[in] vectorDimension Dimension of vector space
  226. * @param[in] intercept Intercept
  227. * @param[in] dualCoefficients Array of dual coefficients
  228. * @param[in] supportVectors Array of support vectors
  229. * @param[in] classes Array of 2 classes ID
  230. * @param[in] coef0 coeff0 (scikit-learn terminology)
  231. * @param[in] gamma gamma (scikit-learn terminology)
  232. * @return none.
  233. *
  234. */
  235. void arm_svm_sigmoid_init_f32(arm_svm_sigmoid_instance_f32 *S,
  236. uint32_t nbOfSupportVectors,
  237. uint32_t vectorDimension,
  238. float32_t intercept,
  239. const float32_t *dualCoefficients,
  240. const float32_t *supportVectors,
  241. const int32_t *classes,
  242. float32_t coef0,
  243. float32_t gamma
  244. );
  245. /**
  246. * @brief SVM sigmoid prediction
  247. * @param[in] S Pointer to an instance of the rbf SVM structure.
  248. * @param[in] in Pointer to input vector
  249. * @param[out] pResult Decision value
  250. * @return none.
  251. *
  252. */
  253. void arm_svm_sigmoid_predict_f32(const arm_svm_sigmoid_instance_f32 *S,
  254. const float32_t * in,
  255. int32_t * pResult);
  256. #ifdef __cplusplus
  257. }
  258. #endif
  259. #endif /* ifndef _SVM_FUNCTIONS_H_ */