svm_functions_f16.h 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. /******************************************************************************
  2. * @file svm_functions_f16.h
  3. * @brief Public header file for NMSIS DSP Library
  4. * @version V1.10.0
  5. * @date 08 July 2021
  6. * Target Processor: RISC-V Cores
  7. ******************************************************************************/
  8. /*
  9. * Copyright (c) 2010-2020 Arm Limited or its affiliates. All rights reserved.
  10. * Copyright (c) 2019 Nuclei Limited. All rights reserved.
  11. *
  12. * SPDX-License-Identifier: Apache-2.0
  13. *
  14. * Licensed under the Apache License, Version 2.0 (the License); you may
  15. * not use this file except in compliance with the License.
  16. * You may obtain a copy of the License at
  17. *
  18. * www.apache.org/licenses/LICENSE-2.0
  19. *
  20. * Unless required by applicable law or agreed to in writing, software
  21. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  22. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. * See the License for the specific language governing permissions and
  24. * limitations under the License.
  25. */
  26. #ifndef SVM_FUNCTIONS_F16_H_
  27. #define SVM_FUNCTIONS_F16_H_
  28. #include "riscv_math_types_f16.h"
  29. #include "riscv_math_memory.h"
  30. #include "dsp/none.h"
  31. #include "dsp/utils.h"
  32. #include "dsp/svm_defines.h"
  33. #ifdef __cplusplus
  34. extern "C"
  35. {
  36. #endif
  37. #if defined(RISCV_FLOAT16_SUPPORTED)
  /*
   * STEP(x): maps a value to a class index, 0 when x <= 0 and 1 otherwise.
   *
   * The whole expansion is parenthesized so the macro composes safely inside
   * larger expressions (e.g. 2 * STEP(v)). Without the outer parentheses,
   * surrounding operators bind into the ?: operands.
   * x is evaluated exactly once.
   *
   * Guarded with #ifndef so an earlier definition is not redefined.
   * NOTE(review): other SVM headers (e.g. the f32 variant) may also define
   * STEP; keep those definitions in sync to avoid redefinition diagnostics.
   */
  #ifndef STEP
  #define STEP(x) (((x) <= 0) ? 0 : 1)
  #endif
  39. /**
  40. * @defgroup groupSVM SVM Functions
  41. * This set of functions implements SVM classification for 2 classes.
  42. * The training must be done with scikit-learn. The parameters can be easily
  43. * generated from the scikit-learn object. Some examples are given in
  44. * DSP/Testing/PatternGeneration/SVM.py
  45. *
  46. * If more than 2 classes are needed, the functions in this folder
  47. * must be used as building blocks to implement multi-class classification.
  48. *
  49. * No multi-class classification is provided in this SVM folder.
  50. *
  51. */
  52. /**
  53. * @brief Instance structure for linear SVM prediction function.
  54. */
  55. typedef struct
  56. {
  57. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  58. uint32_t vectorDimension; /**< Dimension of vector space */
  59. float16_t intercept; /**< Intercept */
  60. const float16_t *dualCoefficients; /**< Dual coefficients */
  61. const float16_t *supportVectors; /**< Support vectors */
  62. const int32_t *classes; /**< The two SVM classes */
  63. } riscv_svm_linear_instance_f16;
  64. /**
  65. * @brief Instance structure for polynomial SVM prediction function.
  66. */
  67. typedef struct
  68. {
  69. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  70. uint32_t vectorDimension; /**< Dimension of vector space */
  71. float16_t intercept; /**< Intercept */
  72. const float16_t *dualCoefficients; /**< Dual coefficients */
  73. const float16_t *supportVectors; /**< Support vectors */
  74. const int32_t *classes; /**< The two SVM classes */
  75. int32_t degree; /**< Polynomial degree */
  76. float16_t coef0; /**< Polynomial constant */
  77. float16_t gamma; /**< Gamma factor */
  78. } riscv_svm_polynomial_instance_f16;
  79. /**
  80. * @brief Instance structure for rbf SVM prediction function.
  81. */
  82. typedef struct
  83. {
  84. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  85. uint32_t vectorDimension; /**< Dimension of vector space */
  86. float16_t intercept; /**< Intercept */
  87. const float16_t *dualCoefficients; /**< Dual coefficients */
  88. const float16_t *supportVectors; /**< Support vectors */
  89. const int32_t *classes; /**< The two SVM classes */
  90. float16_t gamma; /**< Gamma factor */
  91. } riscv_svm_rbf_instance_f16;
  92. /**
  93. * @brief Instance structure for sigmoid SVM prediction function.
  94. */
  95. typedef struct
  96. {
  97. uint32_t nbOfSupportVectors; /**< Number of support vectors */
  98. uint32_t vectorDimension; /**< Dimension of vector space */
  99. float16_t intercept; /**< Intercept */
  100. const float16_t *dualCoefficients; /**< Dual coefficients */
  101. const float16_t *supportVectors; /**< Support vectors */
  102. const int32_t *classes; /**< The two SVM classes */
  103. float16_t coef0; /**< Independent constant */
  104. float16_t gamma; /**< Gamma factor */
  105. } riscv_svm_sigmoid_instance_f16;
  106. /**
  107. * @brief SVM linear instance init function
  108. * @param[in] S Parameters for SVM functions
  109. * @param[in] nbOfSupportVectors Number of support vectors
  110. * @param[in] vectorDimension Dimension of vector space
  111. * @param[in] intercept Intercept
  112. * @param[in] dualCoefficients Array of dual coefficients
  113. * @param[in] supportVectors Array of support vectors
  114. * @param[in] classes Array of 2 classes ID
  115. */
  116. void riscv_svm_linear_init_f16(riscv_svm_linear_instance_f16 *S,
  117. uint32_t nbOfSupportVectors,
  118. uint32_t vectorDimension,
  119. float16_t intercept,
  120. const float16_t *dualCoefficients,
  121. const float16_t *supportVectors,
  122. const int32_t *classes);
  123. /**
  124. * @brief SVM linear prediction
  125. * @param[in] S Pointer to an instance of the linear SVM structure.
  126. * @param[in] in Pointer to input vector
  127. * @param[out] pResult Decision value
  128. */
  129. void riscv_svm_linear_predict_f16(const riscv_svm_linear_instance_f16 *S,
  130. const float16_t * in,
  131. int32_t * pResult);
  132. /**
  133. * @brief SVM polynomial instance init function
  134. * @param[in] S points to an instance of the polynomial SVM structure.
  135. * @param[in] nbOfSupportVectors Number of support vectors
  136. * @param[in] vectorDimension Dimension of vector space
  137. * @param[in] intercept Intercept
  138. * @param[in] dualCoefficients Array of dual coefficients
  139. * @param[in] supportVectors Array of support vectors
  140. * @param[in] classes Array of 2 classes ID
  141. * @param[in] degree Polynomial degree
  142. * @param[in] coef0 coeff0 (scikit-learn terminology)
  143. * @param[in] gamma gamma (scikit-learn terminology)
  144. */
  145. void riscv_svm_polynomial_init_f16(riscv_svm_polynomial_instance_f16 *S,
  146. uint32_t nbOfSupportVectors,
  147. uint32_t vectorDimension,
  148. float16_t intercept,
  149. const float16_t *dualCoefficients,
  150. const float16_t *supportVectors,
  151. const int32_t *classes,
  152. int32_t degree,
  153. float16_t coef0,
  154. float16_t gamma
  155. );
  156. /**
  157. * @brief SVM polynomial prediction
  158. * @param[in] S Pointer to an instance of the polynomial SVM structure.
  159. * @param[in] in Pointer to input vector
  160. * @param[out] pResult Decision value
  161. */
  162. void riscv_svm_polynomial_predict_f16(const riscv_svm_polynomial_instance_f16 *S,
  163. const float16_t * in,
  164. int32_t * pResult);
  165. /**
  166. * @brief SVM radial basis function instance init function
  167. * @param[in] S points to an instance of the polynomial SVM structure.
  168. * @param[in] nbOfSupportVectors Number of support vectors
  169. * @param[in] vectorDimension Dimension of vector space
  170. * @param[in] intercept Intercept
  171. * @param[in] dualCoefficients Array of dual coefficients
  172. * @param[in] supportVectors Array of support vectors
  173. * @param[in] classes Array of 2 classes ID
  174. * @param[in] gamma gamma (scikit-learn terminology)
  175. */
  176. void riscv_svm_rbf_init_f16(riscv_svm_rbf_instance_f16 *S,
  177. uint32_t nbOfSupportVectors,
  178. uint32_t vectorDimension,
  179. float16_t intercept,
  180. const float16_t *dualCoefficients,
  181. const float16_t *supportVectors,
  182. const int32_t *classes,
  183. float16_t gamma
  184. );
  185. /**
  186. * @brief SVM rbf prediction
  187. * @param[in] S Pointer to an instance of the rbf SVM structure.
  188. * @param[in] in Pointer to input vector
  189. * @param[out] pResult decision value
  190. */
  191. void riscv_svm_rbf_predict_f16(const riscv_svm_rbf_instance_f16 *S,
  192. const float16_t * in,
  193. int32_t * pResult);
  194. /**
  195. * @brief SVM sigmoid instance init function
  196. * @param[in] S points to an instance of the rbf SVM structure.
  197. * @param[in] nbOfSupportVectors Number of support vectors
  198. * @param[in] vectorDimension Dimension of vector space
  199. * @param[in] intercept Intercept
  200. * @param[in] dualCoefficients Array of dual coefficients
  201. * @param[in] supportVectors Array of support vectors
  202. * @param[in] classes Array of 2 classes ID
  203. * @param[in] coef0 coeff0 (scikit-learn terminology)
  204. * @param[in] gamma gamma (scikit-learn terminology)
  205. */
  206. void riscv_svm_sigmoid_init_f16(riscv_svm_sigmoid_instance_f16 *S,
  207. uint32_t nbOfSupportVectors,
  208. uint32_t vectorDimension,
  209. float16_t intercept,
  210. const float16_t *dualCoefficients,
  211. const float16_t *supportVectors,
  212. const int32_t *classes,
  213. float16_t coef0,
  214. float16_t gamma
  215. );
  216. /**
  217. * @brief SVM sigmoid prediction
  218. * @param[in] S Pointer to an instance of the rbf SVM structure.
  219. * @param[in] in Pointer to input vector
  220. * @param[out] pResult Decision value
  221. */
  222. void riscv_svm_sigmoid_predict_f16(const riscv_svm_sigmoid_instance_f16 *S,
  223. const float16_t * in,
  224. int32_t * pResult);
  225. #endif /*defined(RISCV_FLOAT16_SUPPORTED)*/
  226. #ifdef __cplusplus
  227. }
  228. #endif
  229. #endif /* ifndef _SVM_FUNCTIONS_F16_H_ */