# lstm_settings.py
# SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  17. import math
  18. from test_settings import TestSettings
  19. import tensorflow as tf
  20. import numpy as np
  21. import tf_keras as keras
  22. class LSTMSettings(TestSettings):
  23. def __init__(self,
  24. dataset,
  25. testtype,
  26. regenerate_weights,
  27. regenerate_input,
  28. regenerate_biases,
  29. schema_file,
  30. batches=2,
  31. time_steps=2,
  32. number_inputs=3,
  33. number_units=4,
  34. time_major=True,
  35. randmin=TestSettings.INT8_MIN,
  36. randmax=TestSettings.INT8_MAX,
  37. generate_bias=True,
  38. interpreter="tensorflow"):
  39. super().__init__(dataset,
  40. testtype,
  41. regenerate_weights,
  42. regenerate_input,
  43. regenerate_biases,
  44. schema_file,
  45. 1,
  46. 1,
  47. 1,
  48. 1,
  49. 1,
  50. 1,
  51. 1,
  52. 1,
  53. False,
  54. randmin,
  55. randmax,
  56. generate_bias=generate_bias,
  57. interpreter=interpreter)
  58. self.batches = batches
  59. self.time_steps = time_steps
  60. self.number_units = number_units
  61. self.number_inputs = number_inputs
  62. self.kernel_hidden_table_file = self.pregenerated_data_dir + self.testdataset + '/' + 'kernel_hidden.txt'
  63. self.time_major = time_major
  64. self.in_activation_max = TestSettings.INT16_MAX
  65. self.in_activation_min = TestSettings.INT16_MIN
  66. self.lstm_scales = []
  67. # Layer indexes. Works with tensorflow 2.10 and 2.11.
  68. self.output_gate_bias_index = 1
  69. self.cell_gate_bias_index = 2
  70. self.forget_gate_bias_index = 3
  71. self.input_gate_bias_index = 4
  72. self.recurrent_input_to_output_w_index = 5
  73. self.recurrent_input_to_cell_w_index = 6
  74. self.recurrent_input_to_forget_w_index = 7
  75. self.recurrent_input_to_input_w_index = 8
  76. self.input_to_output_w_index = 9
  77. self.input_to_cell_w_index = 10
  78. self.input_to_forget_w_index = 11
  79. self.input_to_input_w_index = 12
  80. self.output_state_index = 13
  81. self.cell_state_index = 14
  82. self.input_norm_coeff_index = 15
  83. self.forget_norm_coeff_index = 16
  84. self.cell_norm_coeff_index = 17
  85. self.output_norm_coeff_index = 18
  86. self.effective_hidden_scale_intermediate_index = 20
  87. def generate_data(self, input_data=None, weights=None, hidden_weights=None, biases=None) -> None:
  88. input_dims = [self.batches, self.time_steps, self.number_inputs]
  89. if input_data is not None:
  90. input_data = tf.reshape(input_data, input_dims)
  91. else:
  92. input_data = self.get_randomized_data(input_dims,
  93. self.inputs_table_file,
  94. regenerate=self.regenerate_new_input)
  95. # This will be the same size when there is no projection.
  96. number_cells = self.number_units
  97. # Each LSTM cell has 4 input weights, 4 hidden (recurrent or cell state) weights and 4 biases.
  98. number_w_b = 4
  99. if weights is not None:
  100. weights = tf.reshape(weights, [self.number_inputs, number_cells * number_w_b])
  101. else:
  102. weights = self.get_randomized_data([self.number_inputs, number_cells * number_w_b],
  103. self.kernel_table_file,
  104. regenerate=self.regenerate_new_weights,
  105. decimals=8,
  106. minrange=-1.0,
  107. maxrange=1.0)
  108. if hidden_weights is not None:
  109. hidden_weights = tf.reshape(hidden_weights, [number_cells, number_cells * number_w_b])
  110. else:
  111. hidden_weights = self.get_randomized_data([number_cells, number_cells * number_w_b],
  112. self.kernel_hidden_table_file,
  113. regenerate=self.regenerate_new_weights,
  114. decimals=8,
  115. minrange=-1.0,
  116. maxrange=1.0)
  117. if not self.generate_bias:
  118. biases = [0] * number_cells * number_w_b
  119. if biases is not None:
  120. biases = tf.reshape(biases, [number_cells * number_w_b])
  121. else:
  122. biases = self.get_randomized_data([number_cells * number_w_b],
  123. self.bias_table_file,
  124. regenerate=self.regenerate_new_bias,
  125. decimals=8,
  126. minrange=-1.0,
  127. maxrange=1.0)
  128. # Create a Keras based LSTM model.
  129. input_layer = keras.layers.Input(shape=(self.time_steps, self.number_inputs),
  130. batch_size=self.batches,
  131. name='input')
  132. if self.time_major:
  133. input_layer_transposed = tf.transpose(input_layer, perm=[1, 0, 2])
  134. lstm_layer = keras.layers.LSTM(units=self.number_units,
  135. time_major=self.time_major,
  136. return_sequences=True)(input_layer_transposed)
  137. else:
  138. lstm_layer = keras.layers.LSTM(units=self.number_units,
  139. time_major=self.time_major,
  140. return_sequences=True)(input_layer)
  141. model = keras.Model(input_layer, lstm_layer, name="LSTM")
  142. if self.time_major:
  143. time_major_offset = 1
  144. shape = (self.time_steps, self.batches, self.number_inputs)
  145. else:
  146. time_major_offset = 0
  147. shape = (self.batches, self.time_steps, self.number_inputs)
  148. # Writing weight and bias to model.
  149. print("Updating weights", model.layers[1 + time_major_offset].weights[0].name)
  150. model.layers[1 + time_major_offset].weights[0].assign(weights)
  151. print("Updating hidden weights", model.layers[1 + time_major_offset].weights[1].name)
  152. model.layers[1 + time_major_offset].weights[1].assign(hidden_weights)
  153. print("Updating bias", model.layers[1 + time_major_offset].weights[2].name)
  154. model.layers[1 + time_major_offset].weights[2].assign(biases)
  155. interpreter = self.convert_and_interpret(model, tf.int8, input_data, dataset_shape=shape)
  156. all_layers_details = interpreter.get_tensor_details()
  157. for i in all_layers_details:
  158. self.lstm_scales.append(i['quantization_parameters']['scales'])
  159. input_data_for_index = all_layers_details[0]
  160. input_gate_bias = all_layers_details[self.input_gate_bias_index + time_major_offset]
  161. forget_gate_bias = all_layers_details[self.forget_gate_bias_index + time_major_offset]
  162. cell_gate_bias = all_layers_details[self.cell_gate_bias_index + time_major_offset]
  163. output_gate_bias = all_layers_details[self.output_gate_bias_index + time_major_offset]
  164. input_to_input_w = all_layers_details[self.input_to_input_w_index + time_major_offset]
  165. input_to_forget_w = all_layers_details[self.input_to_forget_w_index + time_major_offset]
  166. input_to_cell_w = all_layers_details[self.input_to_cell_w_index + time_major_offset]
  167. input_to_output_w = all_layers_details[self.input_to_output_w_index + time_major_offset]
  168. recurrent_input_to_input_w = all_layers_details[self.recurrent_input_to_input_w_index + time_major_offset]
  169. recurrent_input_to_forget_w = all_layers_details[self.recurrent_input_to_forget_w_index + time_major_offset]
  170. recurrent_input_to_cell_w = all_layers_details[self.recurrent_input_to_cell_w_index + time_major_offset]
  171. recurrent_input_to_output_w = all_layers_details[self.recurrent_input_to_output_w_index + time_major_offset]
  172. if self.time_major:
  173. time_major_offset = 2
  174. output_state = all_layers_details[self.output_state_index + time_major_offset]
  175. cell_state = all_layers_details[self.cell_state_index + time_major_offset]
  176. input_norm_coeff = all_layers_details[self.input_norm_coeff_index + time_major_offset]
  177. forget_norm_coeff = all_layers_details[self.forget_norm_coeff_index + time_major_offset]
  178. cell_norm_coeff = all_layers_details[self.cell_norm_coeff_index + time_major_offset]
  179. output_norm_coeff = all_layers_details[self.output_norm_coeff_index + time_major_offset]
  180. # For scale and zero point.
  181. effective_hidden_scale_intermediate = all_layers_details[
  182. self.effective_hidden_scale_intermediate_index + time_major_offset]
  183. input_details = interpreter.get_input_details()
  184. output_details = interpreter.get_output_details()
  185. actual_input_data = interpreter.get_tensor(input_details[0]["index"])
  186. if (input_data.numpy().shape != actual_input_data.shape) or \
  187. not ((input_data.numpy().astype(int) == actual_input_data).all().astype(int)):
  188. raise RuntimeError("Input data mismatch")
  189. self.generate_c_array(self.input_data_file_prefix, interpreter.get_tensor(input_data_for_index['index']))
  190. self.generate_c_array("input_to_input_w", interpreter.get_tensor(input_to_input_w['index']))
  191. self.generate_c_array("input_to_forget_w", interpreter.get_tensor(input_to_forget_w['index']))
  192. self.generate_c_array("input_to_cell_w", interpreter.get_tensor(input_to_cell_w['index']))
  193. self.generate_c_array("input_to_output_w", interpreter.get_tensor(input_to_output_w['index']))
  194. self.generate_c_array("recurrent_input_to_input_w", interpreter.get_tensor(recurrent_input_to_input_w['index']))
  195. self.generate_c_array("recurrent_input_to_forget_w",
  196. interpreter.get_tensor(recurrent_input_to_forget_w['index']))
  197. self.generate_c_array("recurrent_input_to_cell_w", interpreter.get_tensor(recurrent_input_to_cell_w['index']))
  198. self.generate_c_array("recurrent_input_to_output_w",
  199. interpreter.get_tensor(recurrent_input_to_output_w['index']))
  200. # Peephole not supported so these are nullptrs.
  201. self.generate_c_array("cell_to_input", [], datatype='int16_t')
  202. self.generate_c_array("cell_to_forget", [], datatype='int16_t')
  203. self.generate_c_array("cell_to_output", [], datatype='int16_t')
  204. self.generate_c_array("input_gate_bias", interpreter.get_tensor(input_gate_bias['index']), datatype='int32_t')
  205. self.generate_c_array("cell_gate_bias", interpreter.get_tensor(cell_gate_bias['index']), datatype='int32_t')
  206. self.generate_c_array("forget_gate_bias", interpreter.get_tensor(forget_gate_bias['index']), datatype='int32_t')
  207. self.generate_c_array("output_gate_bias", interpreter.get_tensor(output_gate_bias['index']), datatype='int32_t')
  208. # Projection not supported so these are nullptrs.
  209. self.generate_c_array("projection_weights", [])
  210. self.generate_c_array("projection_bias", [], datatype='int32_t')
  211. self.generate_c_array("output_state", interpreter.get_tensor(output_state['index']), const="")
  212. self.generate_c_array("cell_state", interpreter.get_tensor(cell_state['index']), datatype='int16_t', const="")
  213. self.generate_c_array("input_norm_coeff", interpreter.get_tensor(input_norm_coeff['index']))
  214. self.generate_c_array("forget_norm_coeff", interpreter.get_tensor(forget_norm_coeff['index']))
  215. self.generate_c_array("cell_norm_coeff", interpreter.get_tensor(cell_norm_coeff['index']))
  216. self.generate_c_array("output_norm_coeff", interpreter.get_tensor(output_norm_coeff['index']))
  217. input_scale = input_data_for_index['quantization_parameters']['scales'][0]
  218. self.data_zp = input_data_for_index['quantization_parameters']['zero_points'][0]
  219. cell_scale = cell_state['quantization_parameters']['scales'][0]
  220. output_state_scale = output_state['quantization_parameters']['scales'][0]
  221. input_zp = input_data_for_index['quantization_parameters']['zero_points'][0]
  222. output_zp = output_details[0]['quantization_parameters']['zero_points'][0]
  223. output_state_zp = output_state['quantization_parameters']['zero_points'][0]
  224. self.hidden_zp = effective_hidden_scale_intermediate['quantization_parameters']['zero_points'][0]
  225. self.output_state_offset = output_state_zp
  226. tmp = math.log(cell_scale) * (1 / math.log(2))
  227. self.cell_state_shift = int(round(tmp))
  228. self.calc_scales(input_scale, output_state_scale, cell_scale)
  229. # Calculate effective biases.
  230. input_zp = -input_zp
  231. output_zp = -output_zp
  232. output_state_zp = -output_state_zp
  233. input_to_forget_eff_bias = self.calc_effective_bias(interpreter, input_zp, input_to_forget_w, forget_gate_bias)
  234. recurrent_to_forget_eff_bias = self.calc_effective_bias(interpreter, output_state_zp,
  235. recurrent_input_to_forget_w, None, False)
  236. input_to_cell_eff_bias = self.calc_effective_bias(interpreter, input_zp, input_to_cell_w, cell_gate_bias)
  237. recurrent_to_cell_eff_bias = self.calc_effective_bias(interpreter, output_state_zp, recurrent_input_to_cell_w,
  238. None, False)
  239. input_to_output_eff_bias = self.calc_effective_bias(interpreter, input_zp, input_to_output_w, output_gate_bias)
  240. recurrent_to_output_eff_bias = self.calc_effective_bias(interpreter, output_state_zp,
  241. recurrent_input_to_output_w, None, False)
  242. input_to_input_eff_bias = self.calc_effective_bias(interpreter, input_zp, input_to_input_w, input_gate_bias)
  243. recurrent_to_input_eff_bias = self.calc_effective_bias(interpreter, output_state_zp, recurrent_input_to_input_w,
  244. None, False)
  245. self.generate_c_array("input_to_input_eff_bias", input_to_input_eff_bias, datatype='int32_t')
  246. self.generate_c_array("input_to_forget_eff_bias", input_to_forget_eff_bias, datatype='int32_t')
  247. self.generate_c_array("input_to_cell_eff_bias", input_to_cell_eff_bias, datatype='int32_t')
  248. self.generate_c_array("input_to_output_eff_bias", input_to_output_eff_bias, datatype='int32_t')
  249. self.generate_c_array("recurrent_to_input_eff_bias", recurrent_to_input_eff_bias, datatype='int32_t')
  250. self.generate_c_array("recurrent_to_cell_eff_bias", recurrent_to_cell_eff_bias, datatype='int32_t')
  251. self.generate_c_array("recurrent_to_forget_eff_bias", recurrent_to_forget_eff_bias, datatype='int32_t')
  252. self.generate_c_array("recurrent_to_output_eff_bias", recurrent_to_output_eff_bias, datatype='int32_t')
  253. # Generate reference
  254. if self.use_tflite_micro_interpreter:
  255. interpreter = self.tflite_micro.runtime.Interpreter.from_file(model_path=str(self.model_path_tflite))
  256. interpreter.set_input(tf.cast(input_data, tf.int8), input_details[0]["index"])
  257. interpreter.invoke()
  258. output_data = interpreter.get_output(0)
  259. else:
  260. interpreter.invoke()
  261. output_data = interpreter.get_tensor(output_details[0]["index"])
  262. self.generate_c_array(self.output_data_file_prefix, output_data, datatype='int8_t')
  263. self.write_c_config_header()
  264. self.write_c_header_wrapper()
  265. def calc_scales(self, input_scale, output_state_scale, cell_scale):
  266. intermediate_scale = pow(2, -12)
  267. if self.time_major:
  268. time_major_offset = 1
  269. else:
  270. time_major_offset = 0
  271. self.effective_forget_scale = pow(2, -15) / cell_scale * cell_scale
  272. self.effective_input_scale = pow(2, -15) / cell_scale * pow(2, -15)
  273. self.effective_hidden_scale = pow(2, -15) / output_state_scale * pow(2, -15)
  274. self.i2i_effective_scale = input_scale * self.lstm_scales[self.input_to_input_w_index + time_major_offset][0] \
  275. / intermediate_scale
  276. self.i2f_effective_scale = input_scale * self.lstm_scales[self.input_to_forget_w_index + time_major_offset][0] \
  277. / intermediate_scale
  278. self.i2c_effective_scale = input_scale * self.lstm_scales[self.input_to_cell_w_index + time_major_offset][0] \
  279. / intermediate_scale
  280. self.i2o_effective_scale = input_scale * self.lstm_scales[self.input_to_output_w_index + time_major_offset][0] \
  281. / intermediate_scale
  282. self.r2i_effective_scale = output_state_scale * self.lstm_scales[self.recurrent_input_to_input_w_index +
  283. time_major_offset][0] / intermediate_scale
  284. self.r2f_effective_scale = output_state_scale * self.lstm_scales[self.recurrent_input_to_forget_w_index +
  285. time_major_offset][0] / intermediate_scale
  286. self.r2c_effective_scale = output_state_scale * self.lstm_scales[self.recurrent_input_to_cell_w_index +
  287. time_major_offset][0] / intermediate_scale
  288. self.r2o_effective_scale = output_state_scale * self.lstm_scales[self.recurrent_input_to_output_w_index +
  289. time_major_offset][0] / intermediate_scale
  290. def calc_effective_bias(self, interpreter, zero_point, weight_tensor, bias_tensor, has_bias=True) -> list:
  291. weights = interpreter.get_tensor(weight_tensor['index'])
  292. dims = weight_tensor['shape']
  293. row = dims[0]
  294. col = dims[1]
  295. if has_bias:
  296. bias_data = interpreter.get_tensor(bias_tensor['index'])
  297. output = bias_data
  298. else:
  299. output = np.zeros((row, ), dtype=np.int32)
  300. for i_row in range(row):
  301. row_sum = 0
  302. for i_col in range(col):
  303. row_sum = row_sum + weights[i_row][i_col]
  304. output[i_row] = output[i_row] + row_sum * zero_point
  305. return output
  306. def write_c_config_header(self) -> None:
  307. super().write_c_config_header(write_common_parameters=False)
  308. filename = self.config_data
  309. filepath = self.headers_dir + filename
  310. prefix = self.testdataset.upper()
  311. with open(filepath, "a") as f:
  312. f.write("#define {}_BUFFER_SIZE {}\n".format(prefix, self.batches * self.number_units))
  313. f.write("#define {}_INPUT_BATCHES {}\n".format(prefix, self.batches))
  314. f.write("#define {}_DST_SIZE {}\n".format(prefix, self.batches * self.time_steps * self.number_units))
  315. f.write("#define {}_TIME_STEPS {}\n".format(prefix, self.time_steps))
  316. f.write("#define {}_NUMBER_UNITS {}\n".format(prefix, self.number_units))
  317. f.write("#define {}_NUMBER_INPUTS {}\n".format(prefix, self.number_inputs))
  318. f.write("#define {}_TIME_MAJOR {}\n".format(prefix, int(self.time_major)))
  319. f.write("#define {}_IN_ACTIVATION_MIN {}\n".format(prefix, self.in_activation_min))
  320. f.write("#define {}_IN_ACTIVATION_MAX {}\n".format(prefix, self.in_activation_max))
  321. (multiplier, shift) = self.quantize_scale(self.i2i_effective_scale)
  322. f.write("#define {}_IN_TO_INPUT_MULTIPLIER {}\n".format(prefix, multiplier))
  323. f.write("#define {}_IN_TO_INPUT_SHIFT {}\n".format(prefix, shift))
  324. (multiplier, shift) = self.quantize_scale(self.i2f_effective_scale)
  325. f.write("#define {}_IN_TO_FORGET_MULTIPLIER {}\n".format(prefix, multiplier))
  326. f.write("#define {}_IN_TO_FORGET_SHIFT {}\n".format(prefix, shift))
  327. (multiplier, shift) = self.quantize_scale(self.i2c_effective_scale)
  328. f.write("#define {}_IN_TO_CELL_MULTIPLIER {}\n".format(prefix, multiplier))
  329. f.write("#define {}_IN_TO_CELL_SHIFT {}\n".format(prefix, shift))
  330. (multiplier, shift) = self.quantize_scale(self.i2o_effective_scale)
  331. f.write("#define {}_IN_TO_OUTPUT_MULTIPLIER {}\n".format(prefix, multiplier))
  332. f.write("#define {}_IN_TO_OUTPUT_SHIFT {}\n".format(prefix, shift))
  333. (multiplier, shift) = self.quantize_scale(self.r2i_effective_scale)
  334. f.write("#define {}_RECURRENT_TO_INPUT_MULTIPLIER {}\n".format(prefix, multiplier))
  335. f.write("#define {}_RECURRENT_TO_INPUT_SHIFT {}\n".format(prefix, shift))
  336. (multiplier, shift) = self.quantize_scale(self.r2f_effective_scale)
  337. f.write("#define {}_RECURRENT_TO_FORGET_MULTIPLIER {}\n".format(prefix, multiplier))
  338. f.write("#define {}_RECURRENT_TO_FORGET_SHIFT {}\n".format(prefix, shift))
  339. (multiplier, shift) = self.quantize_scale(self.r2c_effective_scale)
  340. f.write("#define {}_RECURRENT_TO_CELL_MULTIPLIER {}\n".format(prefix, multiplier))
  341. f.write("#define {}_RECURRENT_TO_CELL_SHIFT {}\n".format(prefix, shift))
  342. (multiplier, shift) = self.quantize_scale(self.r2o_effective_scale)
  343. f.write("#define {}_RECURRENT_TO_OUTPUT_MULTIPLIER {}\n".format(prefix, multiplier))
  344. f.write("#define {}_RECURRENT_TO_OUTPUT_SHIFT {}\n".format(prefix, shift))
  345. (multiplier, shift) = self.quantize_scale(self.effective_forget_scale)
  346. f.write("#define {}_FORGET_MULTIPLIER {}\n".format(prefix, multiplier))
  347. f.write("#define {}_FORGET_SHIFT {}\n".format(prefix, shift))
  348. (multiplier, shift) = self.quantize_scale(self.effective_input_scale)
  349. f.write("#define {}_INPUT_MULTIPLIER {}\n".format(prefix, multiplier))
  350. f.write("#define {}_INPUT_SHIFT {}\n".format(prefix, shift))
  351. (multiplier, shift) = self.quantize_scale(self.effective_hidden_scale)
  352. f.write("#define {}_HIDDEN_MULTIPLIER {}\n".format(prefix, multiplier))
  353. f.write("#define {}_HIDDEN_SHIFT {}\n".format(prefix, shift))
  354. f.write("#define {}_HIDDEN_OFFSET {}\n".format(prefix, self.hidden_zp))
  355. f.write("#define {}_DATA_OFFSET {}\n".format(prefix, -self.data_zp))
  356. f.write("#define {}_OUTPUT_STATE_OFFSET {}\n".format(prefix, self.output_state_offset))
  357. f.write("#define {}_CELL_STATE_SHIFT {}\n".format(prefix, self.cell_state_shift))
  358. for i in range(len(self.lstm_scales)):
  359. if len(self.lstm_scales[i]) == 0:
  360. continue
  361. (multiplier, shift) = self.quantize_scale(self.lstm_scales[i][0])