tinymaix.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. /* Copyright 2022 Sipeed Technology Co., Ltd. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #ifndef __TINYMAIX_H
  13. #define __TINYMAIX_H
  14. #include <stdint.h>
  15. #include <stdio.h>
  16. #include <stdlib.h>
  17. #include <string.h>
  18. #define TM_MDL_INT8 0
  19. #define TM_MDL_INT16 1
  20. #define TM_MDL_FP32 2
  21. #define TM_MDL_FP16 3
  22. #define TM_MDL_FP8_143 4 //experimental
  23. #define TM_MDL_FP8_152 5 //experimental
  24. #include "tm_port.h"
/******************************* MACRO ************************************/
  26. #define TM_MDL_MAGIC 'XIAM' //mdl magic sign
  27. #define TM_ALIGN_SIZE (8) //8 byte align
  28. #define TM_ALIGN(addr) ((((size_t)(addr))+(TM_ALIGN_SIZE-1))/TM_ALIGN_SIZE*TM_ALIGN_SIZE)
  29. #define TM_MATP(mat,y,x,ch) ((mat)->data + ((y)*(mat)->w + (x))*(mat)->c + (ch))
  30. //HWC
  31. #if TM_MDL_TYPE == TM_MDL_INT8
  32. typedef int8_t mtype_t; //mat data type
  33. typedef int8_t wtype_t; //weight data type
  34. typedef int32_t btype_t; //bias data type
  35. typedef int32_t sumtype_t; //sum data type
  36. typedef int32_t zptype_t; //zeropoint data type
  37. #define UINT2INT_SHIFT (0)
  38. #elif TM_MDL_TYPE == TM_MDL_INT16
  39. typedef int16_t mtype_t; //mat data type
  40. typedef int16_t wtype_t; //weight data type
  41. typedef int32_t btype_t; //bias data type
  42. typedef int32_t sumtype_t; //sum data type
  43. typedef int32_t zptype_t; //zeropoint data type
  44. #define UINT2INT_SHIFT (8)
  45. #elif TM_MDL_TYPE == TM_MDL_FP32
  46. typedef float mtype_t; //mat data type
  47. typedef float wtype_t; //weight data type
  48. typedef float btype_t; //bias data type
  49. typedef float sumtype_t; //sum data type
  50. typedef float zptype_t; //zeropoint data type
  51. #elif TM_MDL_TYPE == TM_MDL_FP16
  52. #if TM_ARCH != TM_ARCH_RV64V
  53. #error "only support RV64V's float16!"
  54. #endif
  55. #include <riscv_vector.h>
  56. typedef float16_t mtype_t; //mat data type
  57. typedef float16_t wtype_t; //weight data type
  58. typedef float16_t btype_t; //bias data type
  59. typedef float16_t sumtype_t; //sum data type
  60. typedef float16_t zptype_t; //zeropoint data type
  61. #elif (TM_MDL_TYPE == TM_MDL_FP8_143) || (TM_MDL_TYPE == TM_MDL_FP8_152)
  62. #if TM_ARCH != TM_ARCH_CPU
  63. #error "only support CPU simulation now!"
  64. #endif
  65. typedef uint8_t mtype_t; //mat data type
  66. typedef uint8_t wtype_t; //weight data type
  67. typedef uint8_t btype_t; //bias data type
  68. typedef float sumtype_t; //sum data type
  69. typedef float zptype_t; //zeropoint data type
  70. #else
  71. #error "Not support this MDL_TYPE!"
  72. #endif
  73. #if TM_MDL_TYPE == TM_MDL_FP8_143
  74. #define TM_FP8_SCNT (1)
  75. #define TM_FP8_ECNT (4)
  76. #define TM_FP8_MCNT (3)
  77. #define TM_FP8_BIAS (9)
  78. #elif TM_MDL_TYPE == TM_MDL_FP8_152
  79. #define TM_FP8_SCNT (1)
  80. #define TM_FP8_ECNT (5)
  81. #define TM_FP8_MCNT (2)
  82. #define TM_FP8_BIAS (15)
  83. #endif
  84. typedef float sctype_t;
  85. #define TM_FASTSCALE_SHIFT (8)
  86. /******************************* ENUM ************************************/
  87. typedef enum{
  88. TM_OK = 0,
  89. TM_ERR= 1,
  90. TM_ERR_MAGIC = 2,
  91. TM_ERR_UNSUPPORT = 3,
  92. TM_ERR_OOM = 4,
  93. TM_ERR_LAYERTYPE = 5,
  94. TM_ERR_DIMS = 6,
  95. TM_ERR_TODO = 7,
  96. TM_ERR_MDLTYPE = 8,
  97. TM_ERR_KSIZE = 9,
  98. }tm_err_t;
  99. typedef enum{
  100. TML_CONV2D = 0,
  101. TML_GAP = 1,
  102. TML_FC = 2,
  103. TML_SOFTMAX = 3,
  104. TML_RESHAPE = 4,
  105. TML_DWCONV2D = 5,
  106. TML_ADD = 6,
  107. TML_MAXCNT ,
  108. }tm_layer_type_t;
  109. typedef enum{
  110. TM_PAD_VALID = 0,
  111. TM_PAD_SAME = 1,
  112. }tm_pad_type_t;
  113. typedef enum{
  114. TM_ACT_NONE = 0,
  115. TM_ACT_RELU = 1,
  116. TM_ACT_RELU1 = 2,
  117. TM_ACT_RELU6 = 3,
  118. TM_ACT_TANH = 4,
  119. TM_ACT_SIGNBIT= 5,
  120. TM_ACT_MAXCNT ,
  121. }tm_act_type_t;
  122. typedef enum {
  123. TMPP_NONE = 0,
  124. TMPP_FP2INT = 1, //user own fp buf -> int input buf
  125. TMPP_UINT2INT = 2, //int8: cvt in place; int16: can't cvt in place
  126. TMPP_UINT2FP01 = 3, // u8/255.0
  127. TMPP_UINT2FPN11= 4, // (u8-128)/128
  128. TMPP_UINT2DTYPE= 5, //uint8 to fp16,fp8
  129. TMPP_MAXCNT,
  130. }tm_pp_t;
  131. /******************************* STRUCT ************************************/
  132. //mdlbin in flash
  133. typedef struct{
  134. uint32_t magic; //"MAIX"
  135. uint8_t mdl_type; //0 int8, 1 int16, 2 fp32,
  136. uint8_t out_deq; //0 don't dequant out; 1 dequant out
  137. uint16_t input_cnt; //only support 1 yet
  138. uint16_t output_cnt; //only support 1 yet
  139. uint16_t layer_cnt;
  140. uint32_t buf_size; //main buf size for middle result = pingpong+keep
  141. uint32_t sub_size; //pingpong buf size;
  142. uint16_t in_dims[4]; //0:dims; 1:dim0; 2:dim1; 3:dim2
  143. uint16_t out_dims[4];
  144. uint8_t reserve[28]; //reserve for future
  145. uint8_t layers_body[0];//oft 64 here
  146. }tm_mdlbin_t;
  147. //mdl meta data in ram
  148. typedef struct{
  149. tm_mdlbin_t* b; //bin
  150. void* cb; //Layer callback
  151. uint8_t* buf; //main buf addr
  152. uint8_t* subbuf; //sub buf addr
  153. uint16_t main_alloc; //is main buf alloc or static
  154. uint16_t layer_i; //current layer index
  155. uint8_t* layer_body; //current layer body addr
  156. }tm_mdl_t;
  157. //dims==3, hwc
  158. //dims==2, 1wc
  159. //dims==1, 11c
  160. typedef struct{
  161. uint16_t dims;
  162. uint16_t h;
  163. uint16_t w;
  164. uint16_t c;
  165. union {
  166. mtype_t* data;
  167. float* dataf;
  168. };
  169. }tm_mat_t;
  170. /******************************* LAYER STRUCT ************************************/
  171. typedef struct{ //48byte
  172. uint16_t type; //layer type
  173. uint16_t is_out; //is output
  174. uint32_t size; //8 byte align size for this layer
  175. uint32_t in_oft; //input oft in main buf
  176. uint32_t out_oft; //output oft in main buf
  177. uint16_t in_dims[4]; //0:dims; 1:dim0; 2:dim1; 3:dim2
  178. uint16_t out_dims[4];
  179. //following unit not used in fp32 mode
  180. sctype_t in_s; //input scale,
  181. zptype_t in_zp; //input zeropoint
  182. sctype_t out_s; //output scale
  183. zptype_t out_zp; //output zeropoint
  184. //note: real = scale*(q-zeropoint)
  185. }tml_head_t;
  186. typedef struct{
  187. tml_head_t h;
  188. uint8_t kernel_w;
  189. uint8_t kernel_h;
  190. uint8_t stride_w;
  191. uint8_t stride_h;
  192. uint8_t dilation_w;
  193. uint8_t dilation_h;
  194. uint16_t act; //0 none, 1 relu, 2 relu1, 3 relu6, 4 tanh, 5 sign_bit
  195. uint8_t pad[4]; //top,bottom,left,right
  196. uint32_t depth_mul; //depth_multiplier: if conv2d,=0; else: >=1
  197. uint32_t reserve; //for 8byte align
  198. uint32_t ws_oft; //weight scale oft from this layer start
  199. //skip bias scale: bias_scale = weight_scale*in_scale
  200. uint32_t w_oft; //weight oft from this layer start
  201. uint32_t b_oft; //bias oft from this layer start
  202. //note: bias[c] = bias[c] + (-out_zp)*sum(w[c*chi*maxk:(c+1)*chi*maxk])
  203. // fused in advance (when convert model)
  204. }tml_conv2d_dw_t; //compatible with conv2d and dwconv2d
  205. typedef struct{
  206. tml_head_t h;
  207. }tml_gap_t;
  208. typedef struct{
  209. tml_head_t h;
  210. uint32_t ws_oft; //weight scale oft from this layer start
  211. uint32_t w_oft; //weight oft from this layer start
  212. uint32_t b_oft; //bias oft from this layer start
  213. uint32_t reserve; //for 8byte align
  214. }tml_fc_t;
  215. typedef struct{
  216. tml_head_t h;
  217. }tml_softmax_t;
  218. typedef struct{
  219. tml_head_t h;
  220. }tml_reshape_t;
  221. typedef struct{
  222. tml_head_t h;
  223. uint8_t kernel_w;
  224. uint8_t kernel_h;
  225. uint8_t stride_w;
  226. uint8_t stride_h;
  227. uint8_t dilation_w;
  228. uint8_t dilation_h;
  229. uint16_t act; //0 none, 1 relu, 2 relu1, 3 relu6, 4 tanh, 5 sign_bit
  230. uint8_t pad[4]; //top,bottom,left,right
  231. uint32_t ws_oft; //weight scale oft from this layer start
  232. //skip bias scale: bias_scale = weight_scale*in_scale
  233. uint32_t w_oft; //weight oft from this layer start
  234. uint32_t b_oft; //bias oft from this layer start
  235. //note: bias[c] = bias[c] + (-out_zp)*sum(w[c*chi*maxk:(c+1)*chi*maxk])
  236. // fused in advance (when convert model)
  237. }tml_dwconv2d_t;
  238. typedef struct{
  239. tml_head_t h;
  240. uint32_t in_oft1;
  241. sctype_t in_s1; //input scale,
  242. zptype_t in_zp1; //input zeropoint
  243. uint32_t reserve; //align8
  244. }tml_add_t;
  245. /******************************* TYPE ************************************/
  246. typedef tm_err_t (*tml_stat_t)(tml_head_t* layer, tm_mat_t* in, tm_mat_t* out);
  247. typedef tm_err_t (*tm_cb_t)(tm_mdl_t* mdl, tml_head_t* lh);
  248. /******************************* GLOBAL VARIABLE ************************************/
  249. /******************************* MODEL FUNCTION ************************************/
  250. tm_err_t tm_load (tm_mdl_t* mdl, const uint8_t* bin, uint8_t*buf, tm_cb_t cb, tm_mat_t* in); //load model
  251. void tm_unload(tm_mdl_t* mdl); //remove model
  252. tm_err_t tm_preprocess(tm_mdl_t* mdl, tm_pp_t pp_type, tm_mat_t* in, tm_mat_t* out); //preprocess input data
  253. tm_err_t tm_run (tm_mdl_t* mdl, tm_mat_t* in, tm_mat_t* out); //run model
  254. /******************************* LAYER FUNCTION ************************************/
  255. tm_err_t tml_conv2d_dwconv2d(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
  256. int kw, int kh, int sx, int sy, int dx, int dy, int act, \
  257. int pad_top, int pad_bottom, int pad_left, int pad_right, int dmul, \
  258. sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
  259. tm_err_t tml_gap(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
  260. tm_err_t tml_fc(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
  261. sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
  262. tm_err_t tml_softmax(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
  263. tm_err_t tml_reshape(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
  264. tm_err_t tml_add(tm_mat_t* in0, tm_mat_t* in1, tm_mat_t* out, \
  265. sctype_t in_s0, zptype_t in_zp0, sctype_t in_s1, zptype_t in_zp1, sctype_t out_s, zptype_t out_zp);
  266. /******************************* STAT FUNCTION ************************************/
  267. #if TM_ENABLE_STAT
  268. tm_err_t tm_stat(tm_mdlbin_t* mdl); //stat model
  269. #endif
  270. /******************************* UTILS FUNCTION ************************************/
  271. uint8_t TM_WEAK tm_fp32to8(float fp32);
  272. float TM_WEAK tm_fp8to32(uint8_t fp8);
  273. /******************************* UTILS ************************************/
  274. #define TML_GET_INPUT(mdl,lh) ((mtype_t*)((mdl)->buf + (lh)->in_oft))
  275. #define TML_GET_OUTPUT(mdl,lh) ((mtype_t*)((mdl)->buf + (lh)->out_oft))
  276. #if (TM_MDL_TYPE == TM_MDL_INT8)||(TM_MDL_TYPE == TM_MDL_INT16)
  277. #define TML_DEQUANT(lh, x) (((sumtype_t)(x)-((lh)->out_zp))*((lh)->out_s))
  278. #define TM_DEQUANT(i8,s,zp) (((sumtype_t)(i8)-(zp))*(s))
  279. #define TM_QUANT(fp32,s,zp) ((mtype_t)((fp32)/(s)+zp))
  280. #elif (TM_MDL_TYPE == TM_MDL_FP8_143) || (TM_MDL_TYPE == TM_MDL_FP8_152)
  281. #define TML_DEQUANT(lh, x) (tm_fp8to32(x))
  282. #else //FP32,FP16
  283. #define TML_DEQUANT(lh, x) ((float)(x))
  284. #define TM_DEQUANT(x,s,zp) (x)
  285. #define TM_QUANT(x,s,zp) (x)
  286. #endif
  287. /******************************* LOCAL MATH FUNCTION ************************************/
  288. #if TM_LOCAL_MATH
  289. //http://www.machinedlearnings.com/2011/06/fast-approximate-logarithm-exponential.html
  290. static inline float _exp(float x) {
  291. float p = 1.442695040f * x;
  292. uint32_t i = 0;
  293. uint32_t sign = (i >> 31);
  294. int w = (int) p;
  295. float z = p - (float) w + (float) sign;
  296. union {
  297. uint32_t i;
  298. float f;
  299. } v = {.i = (uint32_t) ((1 << 23) * (p + 121.2740838f + 27.7280233f / (4.84252568f - z) - 1.49012907f * z))};
  300. return v.f;
  301. }
  302. #define tm_exp _exp //maybe some arch have exp acceleration, use macro in arch_xxx.h to reload it
  303. #else
  304. #define tm_exp exp
  305. #endif
  306. #endif