table_gen.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. #!/usr/bin/python
  2. import math
  3. class Table(object):
  4. def __init__(self, table_entry=256, table_range=8):
  5. self.table_entry = table_entry
  6. self.table_range = table_range
  7. pass
  8. def sigmoid(self, x):
  9. return 1 / (1 + math.exp(-1*x))
  10. def tanh(self, x):
  11. return (math.exp(2*x)-1) / (math.exp(2*x)+1)
  12. def fp2q7(self, x):
  13. x_int = math.floor(x*(2**7)+0.5)
  14. if x_int >= 128 :
  15. x_int = 127
  16. if x_int < -128 :
  17. x_int = -128
  18. if x_int >= 0 :
  19. return x_int
  20. else :
  21. return 0x100 + x_int
  22. def fp2q15(self, x):
  23. x_int = math.floor(x*(2**15)+0.5)
  24. if x_int >= 2**15 :
  25. x_int = 2**15-1
  26. if x_int < -1*2**15 :
  27. x_int = -1*2**15
  28. if x_int >= 0 :
  29. return x_int
  30. else :
  31. return 0x10000 + x_int
  32. def table_gen(self):
  33. outfile = open("NNCommonTable.c", "wb")
  34. outfile.write("/*\n * Common tables for NN\n *\n *\n *\n *\n */\n\n#include \"arm_math.h\"\n#include \"NNCommonTable.h\"\n\n/*\n * Table for sigmoid\n */\n")
  35. for function_type in ["sigmoid", "tanh"]:
  36. for data_type in [7, 15]:
  37. out_type = "q"+str(data_type)+"_t"
  38. act_func = getattr(self, function_type)
  39. quan_func = getattr(self, 'fp2q'+str(data_type))
  40. # unified table
  41. outfile.write('const %s %sTable_q%d[%d] = {\n' % (out_type, function_type, data_type, self.table_entry) )
  42. for i in range(self.table_entry):
  43. # convert into actual value
  44. if i < self.table_entry/2:
  45. value_q7 = self.table_range * (i)
  46. else:
  47. value_q7 = self.table_range * (i - self.table_entry)
  48. if data_type == 7:
  49. #outfile.write('%f, ' % (act_func(float(value_q7)/256)))
  50. outfile.write('0x%02x, ' % (quan_func(act_func(float(value_q7)/self.table_entry))))
  51. else:
  52. #outfile.write('%f, ' % (act_func(float(value_q7)/256)))
  53. outfile.write('0x%04x, ' % (quan_func(act_func(float(value_q7)/self.table_entry))))
  54. if i % 8 == 7:
  55. outfile.write("\n")
  56. outfile.write("};\n\n")
  57. for data_type in [15]:
  58. out_type = "q"+str(data_type)+"_t"
  59. act_func = getattr(self, function_type)
  60. quan_func = getattr(self, 'fp2q'+str(data_type))
  61. # H-L tables
  62. outfile.write('const %s %sLTable_q%d[%d] = {\n' % (out_type, function_type, data_type, self.table_entry/2))
  63. for i in range(self.table_entry/2):
  64. # convert into actual value, max value is 16*self.table_entry/4 / 4
  65. # which is equivalent to self.table_entry / self.table_entry/2 = 2, i.e., 1/4 of 8
  66. if i < self.table_entry/4:
  67. value_q7 = self.table_range * i / 4
  68. else:
  69. value_q7 = self.table_range * (i - self.table_entry/2) / 4
  70. if data_type == 7:
  71. #outfile.write('%f, ' % (act_func(float(value_q7)/256)))
  72. outfile.write('0x%02x, ' % (quan_func(act_func(float(value_q7)/(self.table_entry/2)))))
  73. else:
  74. #outfile.write('%f, ' % (act_func(float(value_q7)/256)))
  75. outfile.write('0x%04x, ' % (quan_func(act_func(float(value_q7)/(self.table_entry/2)))))
  76. if i % 8 == 7:
  77. outfile.write("\n")
  78. outfile.write("};\n\n")
  79. outfile.write('const %s %sHTable_q%d[%d] = {\n' % (out_type, function_type, data_type, 3*self.table_entry/4))
  80. for i in range(3 * self.table_entry/4):
  81. # convert into actual value, tageting range (2, 8)
  82. if i < 3*self.table_entry/8 :
  83. value_q7 = self.table_range * ( i + self.table_entry/8 )
  84. else:
  85. value_q7 = self.table_range * ( i + self.table_entry/8 - self.table_entry)
  86. if data_type == 7:
  87. #outfile.write('%f, ' % (act_func(float(value_q7)/256)))
  88. outfile.write('0x%02x, ' % (quan_func(act_func(float(value_q7)/self.table_entry))))
  89. else:
  90. #outfile.write('%f, ' % (act_func(float(value_q7)/256)))
  91. outfile.write('0x%04x, ' % (quan_func(act_func(float(value_q7)/self.table_entry))))
  92. if i % 8 == 7:
  93. outfile.write("\n")
  94. outfile.write("};\n\n")
  95. outfile.close()
  96. mytable = Table(table_entry=256, table_range=16)
  97. mytable.table_gen()