pooling_settings.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
  2. #
  3. # SPDX-License-Identifier: Apache-2.0
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the License); you may
  6. # not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an AS IS BASIS, WITHOUT
  13. # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from test_settings import TestSettings
  18. import numpy as np
  19. import tensorflow as tf
  20. import tf_keras as keras
  21. class PoolingSettings(TestSettings):
  22. def __init__(self,
  23. dataset,
  24. testtype,
  25. regenerate_weights,
  26. regenerate_input,
  27. regenerate_biases,
  28. schema_file,
  29. channels=8,
  30. x_in=4,
  31. y_in=4,
  32. w_x=4,
  33. w_y=4,
  34. stride_x=1,
  35. stride_y=1,
  36. randmin=TestSettings.INT8_MIN,
  37. randmax=TestSettings.INT8_MAX,
  38. bias_min=TestSettings.INT32_MIN,
  39. bias_max=TestSettings.INT32_MAX,
  40. batches=1,
  41. pad=False,
  42. relu6=False,
  43. out_activation_min=None,
  44. out_activation_max=None,
  45. int16xint8=False,
  46. interpreter="tensorflow"):
  47. super().__init__(dataset,
  48. testtype,
  49. regenerate_weights,
  50. regenerate_input,
  51. regenerate_biases,
  52. schema_file,
  53. channels,
  54. channels,
  55. x_in,
  56. y_in,
  57. w_x,
  58. w_y,
  59. stride_x,
  60. stride_y,
  61. pad,
  62. randmin=randmin,
  63. randmax=randmax,
  64. batches=batches,
  65. relu6=relu6,
  66. out_activation_min=out_activation_min,
  67. out_activation_max=out_activation_max,
  68. int16xint8=int16xint8,
  69. interpreter=interpreter)
  70. def generate_data(self, input_data=None) -> None:
  71. if self.is_int16xint8:
  72. datatype = "int16_t"
  73. inttype = tf.int16
  74. else:
  75. datatype = "int8_t"
  76. inttype = tf.int8
  77. input_data = self.get_randomized_input_data(input_data)
  78. self.generate_c_array(self.input_data_file_prefix, input_data, datatype=datatype)
  79. input_data = tf.cast(input_data, tf.float32)
  80. # Create a one-layer Keras model
  81. model = keras.models.Sequential()
  82. input_shape = (self.batches, self.y_input, self.x_input, self.input_ch)
  83. model.add(keras.layers.InputLayer(input_shape=input_shape[1:], batch_size=self.batches))
  84. if self.test_type == 'avgpool':
  85. model.add(
  86. keras.layers.AveragePooling2D(pool_size=(self.filter_y, self.filter_x),
  87. strides=(self.stride_y, self.stride_x),
  88. padding=self.padding,
  89. input_shape=input_shape[1:]))
  90. elif self.test_type == 'maxpool':
  91. model.add(
  92. keras.layers.MaxPooling2D(pool_size=(self.filter_y, self.filter_x),
  93. strides=(self.stride_y, self.stride_x),
  94. padding=self.padding,
  95. input_shape=input_shape[1:]))
  96. else:
  97. raise RuntimeError("Wrong test type")
  98. interpreter = self.convert_and_interpret(model, inttype, input_data)
  99. output_details = interpreter.get_output_details()
  100. self.x_output = output_details[0]['shape'][2]
  101. self.y_output = output_details[0]['shape'][1]
  102. self.calculate_padding(self.x_output, self.y_output, self.x_input, self.y_input)
  103. # Generate reference
  104. interpreter.invoke()
  105. output_data = interpreter.get_tensor(output_details[0]["index"])
  106. self.generate_c_array(self.output_data_file_prefix,
  107. np.clip(output_data, self.out_activation_min, self.out_activation_max),
  108. datatype=datatype)
  109. self.write_c_config_header()
  110. self.write_c_header_wrapper()
  111. def write_c_config_header(self) -> None:
  112. super().write_c_config_header()
  113. filename = self.config_data
  114. filepath = self.headers_dir + filename
  115. prefix = self.testdataset.upper()
  116. with open(filepath, "a") as f:
  117. self.write_common_config(f, prefix)