| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356 |
- /* Copyright 2022 Sipeed Technology Co., Ltd. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
- #ifndef __TINYMAIX_H
- #define __TINYMAIX_H
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Model precision identifiers. They must stay plain macros because they are
 * compared in #if directives against TM_MDL_TYPE (presumably configured by
 * tm_port.h, which is included right after these definitions). */
#define TM_MDL_INT8     0
#define TM_MDL_INT16    1
#define TM_MDL_FP32     2
#define TM_MDL_FP16     3
#define TM_MDL_FP8_143  4   /* experimental */
#define TM_MDL_FP8_152  5   /* experimental */
- #include "tm_port.h"
/******************************* MACRO ************************************/
/* Model binary signature: the bytes "MAIX" viewed as a native 32-bit word on a
 * little-endian target (0x5849414D). Spelled as a hex literal because the
 * multi-character constant 'XIAM' it replaces has an implementation-defined
 * value (C11 6.4.4.4p10), warns under -Wmultichar, and GCC truncates it to the
 * low characters on targets whose int is only 16 bits wide. */
#define TM_MDL_MAGIC (0x5849414D)   //mdl magic sign
#define TM_ALIGN_SIZE (8)           //8 byte align
/* Round addr (pointer or integer) up to the next multiple of TM_ALIGN_SIZE. */
#define TM_ALIGN(addr) ((((size_t)(addr))+(TM_ALIGN_SIZE-1))/TM_ALIGN_SIZE*TM_ALIGN_SIZE)
/* Pointer to element (y, x, ch) of a tm_mat_t stored in HWC order. */
#define TM_MATP(mat,y,x,ch) ((mat)->data + ((y)*(mat)->w + (x))*(mat)->c + (ch))
/* Storage types for the configured precision (TM_MDL_TYPE):
 *   mtype_t   - activation (tm_mat_t) element
 *   wtype_t   - weight element
 *   btype_t   - bias element
 *   sumtype_t - accumulator
 *   zptype_t  - zero point
 * UINT2INT_SHIFT is only provided for the integer models. */
#if (TM_MDL_TYPE == TM_MDL_INT8) || (TM_MDL_TYPE == TM_MDL_INT16)
/* integer models share 32-bit bias, accumulator and zero-point types */
typedef int32_t btype_t;
typedef int32_t sumtype_t;
typedef int32_t zptype_t;
#  if TM_MDL_TYPE == TM_MDL_INT8
typedef int8_t  mtype_t;
typedef int8_t  wtype_t;
#    define UINT2INT_SHIFT (0)
#  else
typedef int16_t mtype_t;
typedef int16_t wtype_t;
#    define UINT2INT_SHIFT (8)
#  endif
#elif TM_MDL_TYPE == TM_MDL_FP32
typedef float mtype_t;
typedef float wtype_t;
typedef float btype_t;
typedef float sumtype_t;
typedef float zptype_t;
#elif TM_MDL_TYPE == TM_MDL_FP16
#  if TM_ARCH != TM_ARCH_RV64V
#    error "only support RV64V's float16!"
#  endif
#  include <riscv_vector.h>
typedef float16_t mtype_t;
typedef float16_t wtype_t;
typedef float16_t btype_t;
typedef float16_t sumtype_t;
typedef float16_t zptype_t;
#elif (TM_MDL_TYPE == TM_MDL_FP8_143) || (TM_MDL_TYPE == TM_MDL_FP8_152)
#  if TM_ARCH != TM_ARCH_CPU
#    error "only support CPU simulation now!"
#  endif
/* FP8 elements are raw bytes; accumulation and zero point are float */
typedef uint8_t mtype_t;
typedef uint8_t wtype_t;
typedef uint8_t btype_t;
typedef float   sumtype_t;
typedef float   zptype_t;
#else
#  error "Not support this MDL_TYPE!"
#endif
/* FP8 format parameters (experimental). The counts match the format names
 * (1-4-3 / 1-5-2), presumably sign/exponent/mantissa bit widths, plus the
 * exponent bias used by the encoder. */
#if (TM_MDL_TYPE == TM_MDL_FP8_143) || (TM_MDL_TYPE == TM_MDL_FP8_152)
#  define TM_FP8_SCNT (1)
#  if TM_MDL_TYPE == TM_MDL_FP8_143
#    define TM_FP8_ECNT (4)
#    define TM_FP8_MCNT (3)
#    define TM_FP8_BIAS (9)
#  else
#    define TM_FP8_ECNT (5)
#    define TM_FP8_MCNT (2)
#    define TM_FP8_BIAS (15)
#  endif
#endif
/* Quantization scales are float in every precision mode. */
typedef float sctype_t;
/* Shift amount for the fast-scale path (consumed outside this header). */
#define TM_FASTSCALE_SHIFT (8)
/******************************* ENUM ************************************/
/* Status codes returned by the tm_* / tml_* functions (TM_OK == 0). */
typedef enum {
    TM_OK            = 0,
    TM_ERR           = 1,
    TM_ERR_MAGIC     = 2,
    TM_ERR_UNSUPPORT = 3,
    TM_ERR_OOM       = 4,
    TM_ERR_LAYERTYPE = 5,
    TM_ERR_DIMS      = 6,
    TM_ERR_TODO      = 7,
    TM_ERR_MDLTYPE   = 8,
    TM_ERR_KSIZE     = 9,
} tm_err_t;
/* Layer kinds; the numeric values are what tml_head_t.type holds. */
typedef enum {
    TML_CONV2D   = 0,
    TML_GAP      = 1,
    TML_FC       = 2,
    TML_SOFTMAX  = 3,
    TML_RESHAPE  = 4,
    TML_DWCONV2D = 5,
    TML_ADD      = 6,
    TML_MAXCNT,          /* number of layer kinds; keep last */
} tm_layer_type_t;
/* Padding modes. */
typedef enum {
    TM_PAD_VALID = 0,
    TM_PAD_SAME  = 1,
} tm_pad_type_t;
/* Activation kinds; values match the `act` field of the conv layer records. */
typedef enum {
    TM_ACT_NONE    = 0,
    TM_ACT_RELU    = 1,
    TM_ACT_RELU1   = 2,
    TM_ACT_RELU6   = 3,
    TM_ACT_TANH    = 4,
    TM_ACT_SIGNBIT = 5,
    TM_ACT_MAXCNT,       /* number of activation kinds; keep last */
} tm_act_type_t;
/* Input preprocessing modes accepted by tm_preprocess(). */
typedef enum {
    TMPP_NONE       = 0,
    TMPP_FP2INT     = 1,  /* user-owned float buffer -> integer input buffer */
    TMPP_UINT2INT   = 2,  /* int8: converted in place; int16: cannot be converted in place */
    TMPP_UINT2FP01  = 3,  /* u8 / 255.0 */
    TMPP_UINT2FPN11 = 4,  /* (u8 - 128) / 128 */
    TMPP_UINT2DTYPE = 5,  /* uint8 to fp16 / fp8 */
    TMPP_MAXCNT,          /* number of modes; keep last */
} tm_pp_t;
/******************************* STRUCT ************************************/
/* Model binary header as stored in flash: a fixed 64-byte header followed by
 * the layer records. Field order and sizes mirror the binary image and must
 * not change. */
typedef struct{
    uint32_t magic;         //"MAIX" (presumably checked against TM_MDL_MAGIC)
    uint8_t  mdl_type;      //TM_MDL_*: 0 int8, 1 int16, 2 fp32, 3 fp16, 4 fp8_143, 5 fp8_152
    uint8_t  out_deq;       //0 don't dequant out; 1 dequant out
    uint16_t input_cnt;     //only support 1 yet
    uint16_t output_cnt;    //only support 1 yet
    uint16_t layer_cnt;     //layer count
    uint32_t buf_size;      //main buf size for middle result = pingpong+keep
    uint32_t sub_size;      //pingpong buf size
    uint16_t in_dims[4];    //0:dims; 1:dim0; 2:dim1; 3:dim2
    uint16_t out_dims[4];   //same encoding as in_dims
    uint8_t  reserve[28];   //reserve for future
    /* Layer records start here (offset 64). Declared as a C99 flexible array
     * member instead of the GNU-only zero-length array `[0]`; the offset and
     * sizeof(tm_mdlbin_t) are unchanged. */
    uint8_t  layers_body[];
}tm_mdlbin_t;
- //mdl meta data in ram
- typedef struct{
- tm_mdlbin_t* b; //bin
- void* cb; //Layer callback
- uint8_t* buf; //main buf addr
- uint8_t* subbuf; //sub buf addr
- uint16_t main_alloc; //is main buf alloc or static
- uint16_t layer_i; //current layer index
- uint8_t* layer_body; //current layer body addr
- }tm_mdl_t;
- //dims==3, hwc
- //dims==2, 1wc
- //dims==1, 11c
- typedef struct{
- uint16_t dims;
- uint16_t h;
- uint16_t w;
- uint16_t c;
- union {
- mtype_t* data;
- float* dataf;
- };
- }tm_mat_t;
- /******************************* LAYER STRUCT ************************************/
- typedef struct{ //48byte
- uint16_t type; //layer type
- uint16_t is_out; //is output
- uint32_t size; //8 byte align size for this layer
- uint32_t in_oft; //input oft in main buf
- uint32_t out_oft; //output oft in main buf
- uint16_t in_dims[4]; //0:dims; 1:dim0; 2:dim1; 3:dim2
- uint16_t out_dims[4];
- //following unit not used in fp32 mode
- sctype_t in_s; //input scale,
- zptype_t in_zp; //input zeropoint
- sctype_t out_s; //output scale
- zptype_t out_zp; //output zeropoint
- //note: real = scale*(q-zeropoint)
- }tml_head_t;
- typedef struct{
- tml_head_t h;
- uint8_t kernel_w;
- uint8_t kernel_h;
- uint8_t stride_w;
- uint8_t stride_h;
-
- uint8_t dilation_w;
- uint8_t dilation_h;
- uint16_t act; //0 none, 1 relu, 2 relu1, 3 relu6, 4 tanh, 5 sign_bit
-
- uint8_t pad[4]; //top,bottom,left,right
- uint32_t depth_mul; //depth_multiplier: if conv2d,=0; else: >=1
- uint32_t reserve; //for 8byte align
-
- uint32_t ws_oft; //weight scale oft from this layer start
- //skip bias scale: bias_scale = weight_scale*in_scale
- uint32_t w_oft; //weight oft from this layer start
- uint32_t b_oft; //bias oft from this layer start
- //note: bias[c] = bias[c] + (-out_zp)*sum(w[c*chi*maxk:(c+1)*chi*maxk])
- // fused in advance (when convert model)
- }tml_conv2d_dw_t; //compatible with conv2d and dwconv2d
- typedef struct{
- tml_head_t h;
- }tml_gap_t;
- typedef struct{
- tml_head_t h;
- uint32_t ws_oft; //weight scale oft from this layer start
- uint32_t w_oft; //weight oft from this layer start
- uint32_t b_oft; //bias oft from this layer start
- uint32_t reserve; //for 8byte align
- }tml_fc_t;
- typedef struct{
- tml_head_t h;
- }tml_softmax_t;
- typedef struct{
- tml_head_t h;
- }tml_reshape_t;
- typedef struct{
- tml_head_t h;
- uint8_t kernel_w;
- uint8_t kernel_h;
- uint8_t stride_w;
- uint8_t stride_h;
-
- uint8_t dilation_w;
- uint8_t dilation_h;
- uint16_t act; //0 none, 1 relu, 2 relu1, 3 relu6, 4 tanh, 5 sign_bit
-
- uint8_t pad[4]; //top,bottom,left,right
-
- uint32_t ws_oft; //weight scale oft from this layer start
- //skip bias scale: bias_scale = weight_scale*in_scale
- uint32_t w_oft; //weight oft from this layer start
- uint32_t b_oft; //bias oft from this layer start
- //note: bias[c] = bias[c] + (-out_zp)*sum(w[c*chi*maxk:(c+1)*chi*maxk])
- // fused in advance (when convert model)
- }tml_dwconv2d_t;
- typedef struct{
- tml_head_t h;
- uint32_t in_oft1;
- sctype_t in_s1; //input scale,
- zptype_t in_zp1; //input zeropoint
- uint32_t reserve; //align8
- }tml_add_t;
- /******************************* TYPE ************************************/
- typedef tm_err_t (*tml_stat_t)(tml_head_t* layer, tm_mat_t* in, tm_mat_t* out);
- typedef tm_err_t (*tm_cb_t)(tm_mdl_t* mdl, tml_head_t* lh);
- /******************************* GLOBAL VARIABLE ************************************/
- /******************************* MODEL FUNCTION ************************************/
- tm_err_t tm_load (tm_mdl_t* mdl, const uint8_t* bin, uint8_t*buf, tm_cb_t cb, tm_mat_t* in); //load model
- void tm_unload(tm_mdl_t* mdl); //remove model
- tm_err_t tm_preprocess(tm_mdl_t* mdl, tm_pp_t pp_type, tm_mat_t* in, tm_mat_t* out); //preprocess input data
- tm_err_t tm_run (tm_mdl_t* mdl, tm_mat_t* in, tm_mat_t* out); //run model
- /******************************* LAYER FUNCTION ************************************/
- tm_err_t tml_conv2d_dwconv2d(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
- int kw, int kh, int sx, int sy, int dx, int dy, int act, \
- int pad_top, int pad_bottom, int pad_left, int pad_right, int dmul, \
- sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
- tm_err_t tml_gap(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
- tm_err_t tml_fc(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
- sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
- tm_err_t tml_softmax(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
- tm_err_t tml_reshape(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
- tm_err_t tml_add(tm_mat_t* in0, tm_mat_t* in1, tm_mat_t* out, \
- sctype_t in_s0, zptype_t in_zp0, sctype_t in_s1, zptype_t in_zp1, sctype_t out_s, zptype_t out_zp);
- /******************************* STAT FUNCTION ************************************/
- #if TM_ENABLE_STAT
- tm_err_t tm_stat(tm_mdlbin_t* mdl); //stat model
- #endif
- /******************************* UTILS FUNCTION ************************************/
- uint8_t TM_WEAK tm_fp32to8(float fp32);
- float TM_WEAK tm_fp8to32(uint8_t fp8);
/******************************* UTILS ************************************/
/* Typed pointers to a layer's input/output activations inside mdl->buf. */
#define TML_GET_INPUT(mdl,lh)  ((mtype_t*)((mdl)->buf + (lh)->in_oft))
#define TML_GET_OUTPUT(mdl,lh) ((mtype_t*)((mdl)->buf + (lh)->out_oft))
#if (TM_MDL_TYPE == TM_MDL_INT8)||(TM_MDL_TYPE == TM_MDL_INT16)
/* Affine quantization: real = scale * (q - zeropoint). */
#define TML_DEQUANT(lh, x)  (((sumtype_t)(x)-((lh)->out_zp))*((lh)->out_s))
#define TM_DEQUANT(i8,s,zp) (((sumtype_t)(i8)-(zp))*(s))
/* q = real/scale + zeropoint, converted with a plain cast to mtype_t (truncates
 * toward zero, no saturation; out-of-range values are the caller's concern).
 * zp is parenthesized so arguments like `c ? a : b` or `a << 1` bind correctly. */
#define TM_QUANT(fp32,s,zp) ((mtype_t)((fp32)/(s)+(zp)))
#elif (TM_MDL_TYPE == TM_MDL_FP8_143) || (TM_MDL_TYPE == TM_MDL_FP8_152)
#define TML_DEQUANT(lh, x)  (tm_fp8to32(x))
#else //FP32,FP16: values are already real numbers
#define TML_DEQUANT(lh, x)  ((float)(x))
#define TM_DEQUANT(x,s,zp)  (x)
#define TM_QUANT(x,s,zp)    (x)
#endif
- /******************************* LOCAL MATH FUNCTION ************************************/
- #if TM_LOCAL_MATH
- //http://www.machinedlearnings.com/2011/06/fast-approximate-logarithm-exponential.html
/* Fast approximation of exp(x) (Mineiro's fastexp, see the link above).
 *
 * Evaluates 2^p with p = x * log2(e). The integer part of p lands in the
 * IEEE-754 exponent field. A rational correction in the fractional part z,
 * which must lie in [0, 1), fills the mantissa. The resulting bit pattern is
 * then reinterpreted as a float.
 *
 * Fixes over the previous version:
 *  - The offset that turns truncation into floor was derived from a constant 0,
 *    so for negative p the fractional part fell in (-1, 0]. Every negative
 *    argument was affected: _exp(-1) gave ~0.434 instead of ~0.368.
 *  - p is clamped at -126 (smallest normal exponent). Without the clamp, very
 *    negative inputs made the bit-pattern value negative, and converting a
 *    negative float to uint32_t is undefined behavior.
 * Large positive inputs (p >= 128) still overflow the exponent field and are
 * not handled here. */
static inline float _exp(float x) {
    float p = 1.442695040f * x;            /* log2(e) * x */
    if (p < -126.0f) {
        p = -126.0f;                       /* result ~2^-126 instead of UB */
    }
    /* (int) truncates toward zero; adding 1 for negative p yields p - floor(p) */
    float offset = (p < 0.0f) ? 1.0f : 0.0f;
    int w = (int) p;
    float z = p - (float) w + offset;
    union {
        uint32_t i;
        float f;
    } v = {.i = (uint32_t) ((1 << 23) * (p + 121.2740838f + 27.7280233f / (4.84252568f - z) - 1.49012907f * z))};
    return v.f;
}
- #define tm_exp _exp //maybe some arch have exp acceleration, use macro in arch_xxx.h to reload it
- #else
- #define tm_exp exp
- #endif
- #endif
|