#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  19. import os
  20. import sys
  21. import json
  22. import argparse
  23. import subprocess
  24. import numpy as np
  25. import tensorflow as tf
  26. from conv_settings import ConvSettings
  27. from softmax_settings import SoftmaxSettings
  28. from fully_connected_settings import FullyConnectedSettings
class MODEL_EXTRACTOR(SoftmaxSettings, FullyConnectedSettings, ConvSettings):
    """Extract operator data (weights, bias, quantization parameters and C
    config headers) from a given tflite model.

    Inherits the per-operator settings/generation helpers from the *Settings
    mixins so each supported operator type can be emitted as unit-test data.
    """

    def __init__(self, dataset, schema_file, tflite_model):
        # dataset: descriptive model name used when naming generated files.
        # schema_file: path to the tflite flatbuffer schema (needed by flatc).
        # tflite_model: path to the .tflite model to extract data from.
        super().__init__(dataset, None, True, True, True, schema_file)

        self.tflite_model = tflite_model

        # Per-tensor quantization parameters (filled in for FULLY_CONNECTED).
        (self.quantized_multiplier, self.quantized_shift) = 0, 0

        self.is_int16xint8 = False  # Only 8-bit supported.
        # Softmax parameters, computed later by calc_softmax_params().
        self.diff_min, self.input_multiplier, self.input_left_shift = 0, 0, 0

        # Operators this extractor can process; anything else is skipped.
        self.supported_ops = ["CONV_2D", "DEPTHWISE_CONV_2D", "FULLY_CONNECTED", "AVERAGE_POOL_2D", "SOFTMAX"]
  37. def from_bytes(self, tensor_data, type_size) -> list:
  38. result = []
  39. tmp_ints = []
  40. if not (type_size == 1 or type_size == 2 or type_size == 4):
  41. raise RuntimeError("Size not supported: {}".format(type_size))
  42. count = 0
  43. for val in tensor_data:
  44. tmp_ints.append(val)
  45. count = count + 1
  46. if count % type_size == 0:
  47. tmp_bytes = bytearray(tmp_ints)
  48. result.append(int.from_bytes(tmp_bytes, 'little', signed=True))
  49. tmp_ints.clear()
  50. return result
  51. def tflite_to_json(self, tflite_input, schema):
  52. name_without_ext, ext = os.path.splitext(tflite_input)
  53. new_name = name_without_ext + '.json'
  54. dirname = os.path.dirname(tflite_input)
  55. if schema is None:
  56. raise RuntimeError("A schema file is required.")
  57. command = f"flatc -o {dirname} --strict-json -t {schema} -- {tflite_input}"
  58. command_list = command.split(' ')
  59. try:
  60. process = subprocess.run(command_list)
  61. if process.returncode != 0:
  62. print(f"ERROR: {command = }")
  63. sys.exit(1)
  64. except Exception as e:
  65. raise RuntimeError(f"{e} from: {command = }. Did you install flatc?")
  66. return new_name
    def write_c_config_header(self, name_prefix, op_name, op_index) -> None:
        """Write <name_prefix>_config_data.h containing #defines for one layer.

        All defines are prefixed '<op_name>_<op_index>_'. SOFTMAX gets its own
        small define set; every other supported op gets the conv/pool set.
        NOTE(review): indentation reconstructed from a flattened source —
        the SOFTMAX/else split below should be confirmed against upstream.
        """
        filename = f"{name_prefix}_config_data.h"

        self.generated_header_files.append(filename)
        filepath = self.headers_dir + filename

        prefix = f'{op_name}_{op_index}'

        print("Writing C header with config data {}...".format(filepath))
        with open(filepath, "w+") as f:
            self.write_c_common_header(f)

            # Defines common to every supported operator.
            f.write("#define {}_OUT_CH {}\n".format(prefix, self.output_ch))
            f.write("#define {}_IN_CH {}\n".format(prefix, self.input_ch))
            f.write("#define {}_INPUT_W {}\n".format(prefix, self.x_input))
            f.write("#define {}_INPUT_H {}\n".format(prefix, self.y_input))
            f.write("#define {}_DST_SIZE {}\n".format(prefix,
                                                      self.x_output * self.y_output * self.output_ch * self.batches))

            if op_name == "SOFTMAX":
                f.write("#define {}_NUM_ROWS {}\n".format(prefix, self.y_input))
                f.write("#define {}_ROW_SIZE {}\n".format(prefix, self.x_input))
                f.write("#define {}_MULT {}\n".format(prefix, self.input_multiplier))
                f.write("#define {}_SHIFT {}\n".format(prefix, self.input_left_shift))
                if not self.is_int16xint8:
                    f.write("#define {}_DIFF_MIN {}\n".format(prefix, -self.diff_min))
            else:
                # Both X/W and Y/H spellings are emitted with the same value
                # (two naming conventions for the same quantity).
                f.write("#define {}_FILTER_X {}\n".format(prefix, self.filter_x))
                f.write("#define {}_FILTER_Y {}\n".format(prefix, self.filter_y))
                f.write("#define {}_FILTER_W {}\n".format(prefix, self.filter_x))
                f.write("#define {}_FILTER_H {}\n".format(prefix, self.filter_y))
                f.write("#define {}_STRIDE_X {}\n".format(prefix, self.stride_x))
                f.write("#define {}_STRIDE_Y {}\n".format(prefix, self.stride_y))
                f.write("#define {}_STRIDE_W {}\n".format(prefix, self.stride_x))
                f.write("#define {}_STRIDE_H {}\n".format(prefix, self.stride_y))
                f.write("#define {}_PAD_X {}\n".format(prefix, self.pad_x))
                f.write("#define {}_PAD_Y {}\n".format(prefix, self.pad_y))
                f.write("#define {}_PAD_W {}\n".format(prefix, self.pad_x))
                f.write("#define {}_PAD_H {}\n".format(prefix, self.pad_y))
                f.write("#define {}_OUTPUT_W {}\n".format(prefix, self.x_output))
                f.write("#define {}_OUTPUT_H {}\n".format(prefix, self.y_output))
                # Offsets are negated/used as-is per CMSIS-NN convention.
                f.write("#define {}_INPUT_OFFSET {}\n".format(prefix, -self.input_zero_point))
                f.write("#define {}_INPUT_SIZE {}\n".format(prefix, self.x_input * self.y_input * self.input_ch))
                f.write("#define {}_OUT_ACTIVATION_MIN {}\n".format(prefix, self.out_activation_min))
                f.write("#define {}_OUT_ACTIVATION_MAX {}\n".format(prefix, self.out_activation_max))
                f.write("#define {}_INPUT_BATCHES {}\n".format(prefix, self.batches))
                f.write("#define {}_OUTPUT_OFFSET {}\n".format(prefix, self.output_zero_point))
                f.write("#define {}_DILATION_X {}\n".format(prefix, self.dilation_x))
                f.write("#define {}_DILATION_Y {}\n".format(prefix, self.dilation_y))
                f.write("#define {}_DILATION_W {}\n".format(prefix, self.dilation_x))
                f.write("#define {}_DILATION_H {}\n".format(prefix, self.dilation_y))

                if op_name == "FULLY_CONNECTED":
                    f.write("#define {}_OUTPUT_MULTIPLIER {}\n".format(prefix, self.quantized_multiplier))
                    f.write("#define {}_OUTPUT_SHIFT {}\n".format(prefix, self.quantized_shift))

                if op_name == "DEPTHWISE_CONV_2D":
                    f.write("#define {}_ACCUMULATION_DEPTH {}\n".format(prefix,
                                                                        self.input_ch * self.x_input * self.y_input))

        self.format_output_file(filepath)
    def shape_to_config(self, input_shape, filter_shape, output_shape, layer_name):
        """Map tensor shapes from the model onto the settings attributes
        (batches, x/y input/output, channels, filter dims) for layer_name.

        NOTE(review): indentation reconstructed from a flattened source — the
        trailing len(input_shape)==4 section is assumed to apply to all layer
        types, with calculate_padding only for 4D (NHWC) inputs; confirm.
        """
        if layer_name == "AVERAGE_POOL_2D":
            # Pooling filter dims are taken from the input plane here.
            [_, self.filter_y, self.filter_x, _] = input_shape
        elif layer_name == "CONV_2D" or layer_name == "DEPTHWISE_CONV_2D":
            [self.batches, self.y_input, self.x_input, self.input_ch] = input_shape
            # output_ch is a discarded local: self.output_ch is instead set
            # from the per-channel scale count in extract_from_model.
            [output_ch, self.filter_y, self.filter_x, self.input_ch] = filter_shape
        elif layer_name == "FULLY_CONNECTED":
            [self.batches, self.input_ch] = input_shape
            [self.input_ch, self.output_ch] = filter_shape
            [self.y_output, self.x_output] = output_shape
            # Fully connected is treated as a 1x1 spatial input.
            self.x_input = 1
            self.y_input = 1
        elif layer_name == "SOFTMAX":
            [self.y_input, self.x_input] = input_shape

        if len(input_shape) == 4:
            if len(output_shape) == 2:
                [self.y_output, self.x_output] = output_shape
            else:
                # 4D NHWC output; batch and channel dims are ignored here.
                [d, self.y_output, self.x_output, d1] = output_shape

            self.calculate_padding(self.x_output, self.y_output, self.x_input, self.y_input)
    def extract_from_model(self, json_file, tensor_details):
        """Walk the operators of the flatc-generated json model and emit C data.

        For every supported operator the per-op settings are derived, then the
        weights/bias arrays, quantization multipliers/shifts and a C config
        header are generated. Unsupported operators are skipped with a warning.

        Args:
            json_file: json representation of the model (from tflite_to_json).
            tensor_details: interpreter tensor details, used for quantization
                parameters at full float precision (json values are truncated).

        NOTE(review): indentation reconstructed from a flattened source —
        op_index is assumed to advance for every operator, including skipped
        ones; confirm against upstream.
        """
        with open(json_file, 'r') as in_file:
            data = in_file.read()

        data = json.loads(data)
        tensors = data['subgraphs'][0]['tensors']
        operators = data['subgraphs'][0]['operators']
        operator_codes = data['operator_codes']
        buffers = data['buffers']

        op_index = 0
        for op in operators:
            if 'opcode_index' in op:
                builtin_name = operator_codes[op['opcode_index']]['builtin_code']
            else:
                builtin_name = ""

            # Get stride and padding.
            if 'builtin_options' in op:
                builtin_options = op['builtin_options']
                if 'stride_w' in builtin_options:
                    self.stride_x = builtin_options['stride_w']
                if 'stride_h' in builtin_options:
                    self.stride_y = builtin_options['stride_h']
                # NOTE(review): flatc strict-json omits fields holding their
                # default value, so a present 'padding' key is taken to mean
                # the non-default VALID padding — confirm against the schema.
                if 'padding' in builtin_options:
                    self.has_padding = False
                    self.padding = 'VALID'
                else:
                    self.has_padding = True
                    self.padding = 'SAME'

            # Generate weights, bias, multipliers, shifts and config.
            if builtin_name not in self.supported_ops:
                print(f"WARNING: skipping unsupported operator {builtin_name}")
            else:
                input_index = op['inputs'][0]
                output_index = op['outputs'][0]
                input_tensor = tensor_details[input_index]
                output_tensor = tensor_details[output_index]

                # Scale/zero point come from the interpreter details, not the
                # json, for full float precision.
                input_scale = input_tensor['quantization'][0]
                output_scale = output_tensor['quantization'][0]
                self.input_zero_point = input_tensor['quantization'][1]
                self.output_zero_point = output_tensor['quantization'][1]
                input_shape = input_tensor['shape']
                output_shape = output_tensor['shape']

                if builtin_name == "CONV_2D" or builtin_name == "DEPTHWISE_CONV_2D" \
                        or builtin_name == "FULLY_CONNECTED":
                    # Weighted ops: pull int8 weights and int32 bias out of the
                    # flatbuffer byte buffers.
                    weights_index = op['inputs'][1]
                    bias_index = op['inputs'][2]
                    weight_tensor = tensor_details[weights_index]
                    scaling_factors = weight_tensor['quantization_parameters']['scales'].tolist()
                    bias = tensors[bias_index]
                    weights = tensors[weights_index]
                    weights_data_index = weights['buffer']
                    weights_data_buffer = buffers[weights_data_index]
                    weights_data = self.from_bytes(weights_data_buffer['data'], 1)
                    bias_data_index = bias['buffer']
                    bias_data_buffer = buffers[bias_data_index]
                    bias_data = self.from_bytes(bias_data_buffer['data'], 4)
                    # One scale per output channel (per-channel quantization).
                    self.output_ch = len(scaling_factors)
                    filter_shape = weights['shape']
                else:
                    filter_shape = []

                self.input_scale, self.output_scale = input_scale, output_scale

                if builtin_name == "SOFTMAX":
                    self.calc_softmax_params()
                self.shape_to_config(input_shape, filter_shape, output_shape, builtin_name)

                nice_name = 'layer_' + str(op_index) + '_' + builtin_name.lower()
                if builtin_name == "CONV_2D" or builtin_name == "DEPTHWISE_CONV_2D" \
                        or builtin_name == "FULLY_CONNECTED":
                    self.generate_c_array(nice_name + "_weights", weights_data)
                    self.generate_c_array(nice_name + "_bias", bias_data, datatype='int32_t')

                    if builtin_name == "FULLY_CONNECTED":
                        # Per-tensor quantization for fully connected.
                        self.weights_scale = scaling_factors[0]
                        self.quantize_multiplier()
                    elif builtin_name == "CONV_2D" or builtin_name == "DEPTHWISE_CONV_2D":
                        # Per-channel multipliers/shifts for (depthwise) conv.
                        self.scaling_factors = scaling_factors
                        per_channel_multiplier, per_channel_shift = self.generate_quantize_per_channel_multiplier()
                        self.generate_c_array(f"{nice_name}_output_mult", per_channel_multiplier, datatype='int32_t')
                        self.generate_c_array(f"{nice_name}_output_shift", per_channel_shift, datatype='int32_t')

                self.write_c_config_header(nice_name, builtin_name, op_index)
            op_index = op_index + 1
    def generate_data(self, input_data=None, weights=None, biases=None) -> None:
        """Run the model through the TFLite reference interpreter and emit all
        C test data: input, per-layer weights/bias/config, and output_ref.

        Args:
            input_data: optional input tensor; randomized when None.
            weights, biases: accepted for interface compatibility but not used
                by this implementation.
        """
        interpreter = self.Interpreter(model_path=str(self.tflite_model),
                                       experimental_op_resolver_type=self.OpResolverType.BUILTIN_REF)
        interpreter.allocate_tensors()

        # Needed for input/output scale/zp as equivalent json file data has too low precision.
        tensor_details = interpreter.get_tensor_details()
        output_details = interpreter.get_output_details()
        (self.output_scale, self.output_zero_point) = output_details[0]['quantization']
        input_details = interpreter.get_input_details()
        if len(input_details) != 1:
            raise RuntimeError("Only single input supported.")
        input_shape = input_details[0]['shape']
        input_data = self.get_randomized_input_data(input_data, input_shape)
        # Only 8-bit input supported (see is_int16xint8 in __init__).
        interpreter.set_tensor(input_details[0]["index"], tf.cast(input_data, tf.int8))
        self.generate_c_array("input", input_data)

        # Per-layer data is extracted from the json form of the model.
        json_file = self.tflite_to_json(self.tflite_model, self.schema_file)
        self.extract_from_model(json_file, tensor_details)

        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]["index"])
        # Reference output is clamped to the activation range.
        self.generate_c_array("output_ref", np.clip(output_data, self.out_activation_min, self.out_activation_max))

        self.write_c_header_wrapper()
  239. if __name__ == '__main__':
  240. parser = argparse.ArgumentParser(description="Extract operator data from given model if operator is supported."
  241. "This provides a way for CMSIS-NN to directly process a model.")
  242. parser.add_argument('--schema-file', type=str, required=True, help="Path to schema file.")
  243. parser.add_argument('--tflite-model', type=str, required=True, help="Path to tflite file.")
  244. parser.add_argument('--model-name',
  245. type=str,
  246. help="Descriptive model name. If left out it will be inferred from actual model.")
  247. args = parser.parse_args()
  248. schema_file = args.schema_file
  249. tflite_model = args.tflite_model
  250. if args.model_name:
  251. dataset = args.model_name
  252. else:
  253. dataset, _ = os.path.splitext(os.path.basename(tflite_model))
  254. model_extractor = MODEL_EXTRACTOR(dataset, schema_file, tflite_model)
  255. model_extractor.generate_data()