svm_functions.h 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. /******************************************************************************
  2. * @file svm_functions.h
  3. * @brief Public header file for NMSIS DSP Library
  4. * @version V1.10.0
  5. * @date 08 July 2021
  6. * Target Processor: RISC-V Cores
  7. ******************************************************************************/
  8. /*
  9. * Copyright (c) 2010-2020 Arm Limited or its affiliates. All rights reserved.
  10. * Copyright (c) 2019 Nuclei Limited. All rights reserved.
  11. *
  12. * SPDX-License-Identifier: Apache-2.0
  13. *
  14. * Licensed under the Apache License, Version 2.0 (the License); you may
  15. * not use this file except in compliance with the License.
  16. * You may obtain a copy of the License at
  17. *
  18. * www.apache.org/licenses/LICENSE-2.0
  19. *
  20. * Unless required by applicable law or agreed to in writing, software
  21. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  22. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. * See the License for the specific language governing permissions and
  24. * limitations under the License.
  25. */
  26. #ifndef SVM_FUNCTIONS_H_
  27. #define SVM_FUNCTIONS_H_
  28. #include "riscv_math_types.h"
  29. #include "riscv_math_memory.h"
  30. #include "dsp/none.h"
  31. #include "dsp/utils.h"
  32. #include "dsp/svm_defines.h"
  33. #ifdef __cplusplus
  34. extern "C"
  35. {
  36. #endif
  37. #define STEP(x) (x) <= 0 ? 0 : 1
  38. /**
  39. * @defgroup groupSVM SVM Functions
  40. * This set of functions implements SVM classification on 2 classes.
  41. * The training must be done with scikit-learn. The parameters can be easily
  42. * generated from the scikit-learn object. Some examples are given in
  43. * DSP/Testing/PatternGeneration/SVM.py
  44. *
  45. * If more than 2 classes are needed, the functions in this folder
  46. * will have to be used, as building blocks, to do multi-class classification.
  47. *
  48. * No multi-class classification is provided in this SVM folder.
  49. *
  50. */
  51. /**
  52. * @brief Integer exponentiation
  53. * @param[in] x value
  54. * @param[in] nb integer exponent >= 1
  55. * @return x^nb
  56. */
  57. __STATIC_INLINE float32_t riscv_exponent_f32(float32_t x, int32_t nb)
  58. {
  59. float32_t r = x;
  60. nb --;
  61. while(nb > 0)
  62. {
  63. r = r * x;
  64. nb--;
  65. }
  66. return(r);
  67. }
  68. /**
  69. * @brief Instance structure for linear SVM prediction function.
  70. */
  71. typedef struct
  72. {
  73. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  74. uint32_t vectorDimension; /**< Dimension of vector space */
  75. float32_t intercept; /**< Intercept */
  76. const float32_t *dualCoefficients; /**< Dual coefficients */
  77. const float32_t *supportVectors; /**< Support vectors */
  78. const int32_t *classes; /**< The two SVM classes */
  79. } riscv_svm_linear_instance_f32;
  80. /**
  81. * @brief Instance structure for polynomial SVM prediction function.
  82. */
  83. typedef struct
  84. {
  85. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  86. uint32_t vectorDimension; /**< Dimension of vector space */
  87. float32_t intercept; /**< Intercept */
  88. const float32_t *dualCoefficients; /**< Dual coefficients */
  89. const float32_t *supportVectors; /**< Support vectors */
  90. const int32_t *classes; /**< The two SVM classes */
  91. int32_t degree; /**< Polynomial degree */
  92. float32_t coef0; /**< Polynomial constant */
  93. float32_t gamma; /**< Gamma factor */
  94. } riscv_svm_polynomial_instance_f32;
  95. /**
  96. * @brief Instance structure for rbf SVM prediction function.
  97. */
  98. typedef struct
  99. {
  100. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  101. uint32_t vectorDimension; /**< Dimension of vector space */
  102. float32_t intercept; /**< Intercept */
  103. const float32_t *dualCoefficients; /**< Dual coefficients */
  104. const float32_t *supportVectors; /**< Support vectors */
  105. const int32_t *classes; /**< The two SVM classes */
  106. float32_t gamma; /**< Gamma factor */
  107. } riscv_svm_rbf_instance_f32;
  108. /**
  109. * @brief Instance structure for sigmoid SVM prediction function.
  110. */
  111. typedef struct
  112. {
  113. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  114. uint32_t vectorDimension; /**< Dimension of vector space */
  115. float32_t intercept; /**< Intercept */
  116. const float32_t *dualCoefficients; /**< Dual coefficients */
  117. const float32_t *supportVectors; /**< Support vectors */
  118. const int32_t *classes; /**< The two SVM classes */
  119. float32_t coef0; /**< Independent constant */
  120. float32_t gamma; /**< Gamma factor */
  121. } riscv_svm_sigmoid_instance_f32;
  122. /**
  123. * @brief SVM linear instance init function
  124. * @param[in] S Parameters for SVM functions
  125. * @param[in] nbOfSupportVectors Number of support vectors
  126. * @param[in] vectorDimension Dimension of vector space
  127. * @param[in] intercept Intercept
  128. * @param[in] dualCoefficients Array of dual coefficients
  129. * @param[in] supportVectors Array of support vectors
  130. * @param[in] classes Array of 2 classes ID
  131. */
  132. void riscv_svm_linear_init_f32(riscv_svm_linear_instance_f32 *S,
  133. uint32_t nbOfSupportVectors,
  134. uint32_t vectorDimension,
  135. float32_t intercept,
  136. const float32_t *dualCoefficients,
  137. const float32_t *supportVectors,
  138. const int32_t *classes);
  139. /**
  140. * @brief SVM linear prediction
  141. * @param[in] S Pointer to an instance of the linear SVM structure.
  142. * @param[in] in Pointer to input vector
  143. * @param[out] pResult Decision value
  144. */
  145. void riscv_svm_linear_predict_f32(const riscv_svm_linear_instance_f32 *S,
  146. const float32_t * in,
  147. int32_t * pResult);
  148. /**
  149. * @brief SVM polynomial instance init function
  150. * @param[in] S points to an instance of the polynomial SVM structure.
  151. * @param[in] nbOfSupportVectors Number of support vectors
  152. * @param[in] vectorDimension Dimension of vector space
  153. * @param[in] intercept Intercept
  154. * @param[in] dualCoefficients Array of dual coefficients
  155. * @param[in] supportVectors Array of support vectors
  156. * @param[in] classes Array of 2 classes ID
  157. * @param[in] degree Polynomial degree
  158. * @param[in] coef0 coeff0 (scikit-learn terminology)
  159. * @param[in] gamma gamma (scikit-learn terminology)
  160. */
  161. void riscv_svm_polynomial_init_f32(riscv_svm_polynomial_instance_f32 *S,
  162. uint32_t nbOfSupportVectors,
  163. uint32_t vectorDimension,
  164. float32_t intercept,
  165. const float32_t *dualCoefficients,
  166. const float32_t *supportVectors,
  167. const int32_t *classes,
  168. int32_t degree,
  169. float32_t coef0,
  170. float32_t gamma
  171. );
  172. /**
  173. * @brief SVM polynomial prediction
  174. * @param[in] S Pointer to an instance of the polynomial SVM structure.
  175. * @param[in] in Pointer to input vector
  176. * @param[out] pResult Decision value
  177. */
  178. void riscv_svm_polynomial_predict_f32(const riscv_svm_polynomial_instance_f32 *S,
  179. const float32_t * in,
  180. int32_t * pResult);
  181. /**
  182. * @brief SVM radial basis function instance init function
  183. * @param[in] S points to an instance of the polynomial SVM structure.
  184. * @param[in] nbOfSupportVectors Number of support vectors
  185. * @param[in] vectorDimension Dimension of vector space
  186. * @param[in] intercept Intercept
  187. * @param[in] dualCoefficients Array of dual coefficients
  188. * @param[in] supportVectors Array of support vectors
  189. * @param[in] classes Array of 2 classes ID
  190. * @param[in] gamma gamma (scikit-learn terminology)
  191. */
  192. void riscv_svm_rbf_init_f32(riscv_svm_rbf_instance_f32 *S,
  193. uint32_t nbOfSupportVectors,
  194. uint32_t vectorDimension,
  195. float32_t intercept,
  196. const float32_t *dualCoefficients,
  197. const float32_t *supportVectors,
  198. const int32_t *classes,
  199. float32_t gamma
  200. );
  201. /**
  202. * @brief SVM rbf prediction
  203. * @param[in] S Pointer to an instance of the rbf SVM structure.
  204. * @param[in] in Pointer to input vector
  205. * @param[out] pResult decision value
  206. */
  207. void riscv_svm_rbf_predict_f32(const riscv_svm_rbf_instance_f32 *S,
  208. const float32_t * in,
  209. int32_t * pResult);
  210. /**
  211. * @brief SVM sigmoid instance init function
  212. * @param[in] S points to an instance of the rbf SVM structure.
  213. * @param[in] nbOfSupportVectors Number of support vectors
  214. * @param[in] vectorDimension Dimension of vector space
  215. * @param[in] intercept Intercept
  216. * @param[in] dualCoefficients Array of dual coefficients
  217. * @param[in] supportVectors Array of support vectors
  218. * @param[in] classes Array of 2 classes ID
  219. * @param[in] coef0 coeff0 (scikit-learn terminology)
  220. * @param[in] gamma gamma (scikit-learn terminology)
  221. */
  222. void riscv_svm_sigmoid_init_f32(riscv_svm_sigmoid_instance_f32 *S,
  223. uint32_t nbOfSupportVectors,
  224. uint32_t vectorDimension,
  225. float32_t intercept,
  226. const float32_t *dualCoefficients,
  227. const float32_t *supportVectors,
  228. const int32_t *classes,
  229. float32_t coef0,
  230. float32_t gamma
  231. );
  232. /**
  233. * @brief SVM sigmoid prediction
  234. * @param[in] S Pointer to an instance of the rbf SVM structure.
  235. * @param[in] in Pointer to input vector
  236. * @param[out] pResult Decision value
  237. */
  238. void riscv_svm_sigmoid_predict_f32(const riscv_svm_sigmoid_instance_f32 *S,
  239. const float32_t * in,
  240. int32_t * pResult);
  241. #ifdef __cplusplus
  242. }
  243. #endif
  244. #endif /* ifndef SVM_FUNCTIONS_H_ */