# quantized.py
  1. # Copyright (C) 2019 Intel Corporation. All rights reserved.
  2. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  3. import tensorflow as tf
  4. import numpy as np
  5. import pathlib
  6. model = tf.keras.Sequential([
  7. tf.keras.layers.InputLayer(input_shape=[5, 5, 1]),
  8. tf.keras.layers.AveragePooling2D(
  9. pool_size=(5, 5), strides=None, padding="valid", data_format=None)
  10. ])
  11. def representative_dataset():
  12. for _ in range(1000):
  13. data = np.random.randint(0, 25, (1, 5, 5, 1))
  14. yield [data.astype(np.float32)]
  15. converter = tf.lite.TFLiteConverter.from_keras_model(model)
  16. converter.optimizations = [tf.lite.Optimize.DEFAULT]
  17. converter.representative_dataset = representative_dataset
  18. converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
  19. converter.inference_input_type = tf.uint8 # or tf.int8
  20. converter.inference_output_type = tf.uint8 # or tf.int8
  21. tflite_model = converter.convert()
  22. tflite_models_dir = pathlib.Path("./")
  23. tflite_model_file = tflite_models_dir / "quantized_model.tflite"
  24. tflite_model_file.write_bytes(tflite_model)