cpu_ops_body.h 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. /* Copyright 2019-2020 Canaan Inc.
  2. *
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. *
  7. * http://www.apache.org/licenses/LICENSE-2.0
  8. *
  9. * Unless required by applicable law or agreed to in writing, software
  10. * distributed under the License is distributed on an "AS IS" BASIS,
  11. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. * See the License for the specific language governing permissions and
  13. * limitations under the License.
  14. */
  15. #pragma once
  16. #include "../node_body.h"
  17. namespace nncase
  18. {
  19. namespace runtime
  20. {
  21. namespace cpu
  22. {
  23. struct cpu_conv2d_options
  24. {
  25. memory_range input;
  26. memory_range output;
  27. runtime_shape_t in_shape;
  28. int32_t out_channels;
  29. padding padding_h;
  30. padding padding_w;
  31. int32_t filter_h;
  32. int32_t filter_w;
  33. int32_t stride_h;
  34. int32_t stride_w;
  35. int32_t dilation_h;
  36. int32_t dilation_w;
  37. value_range<float> fused_activation;
  38. xtl::span<const float> weights;
  39. xtl::span<const float> bias;
  40. void deserialize(span_reader &reader)
  41. {
  42. reader.read(input);
  43. reader.read(output);
  44. reader.read(in_shape);
  45. reader.read(out_channels);
  46. reader.read(padding_h);
  47. reader.read(padding_w);
  48. reader.read(filter_h);
  49. reader.read(filter_w);
  50. reader.read(stride_h);
  51. reader.read(stride_w);
  52. reader.read(dilation_h);
  53. reader.read(dilation_w);
  54. reader.read(fused_activation);
  55. reader.read_span(weights, (size_t)out_channels * in_shape[3] * filter_h * filter_w);
  56. reader.read_span(bias, out_channels);
  57. }
  58. };
  59. struct cpu_depthwise_conv2d_options
  60. {
  61. memory_range input;
  62. memory_range output;
  63. runtime_shape_t in_shape;
  64. padding padding_h;
  65. padding padding_w;
  66. int32_t filter_h;
  67. int32_t filter_w;
  68. int32_t stride_h;
  69. int32_t stride_w;
  70. int32_t dilation_h;
  71. int32_t dilation_w;
  72. value_range<float> fused_activation;
  73. xtl::span<const float> weights;
  74. xtl::span<const float> bias;
  75. void deserialize(span_reader &reader)
  76. {
  77. reader.read(input);
  78. reader.read(output);
  79. reader.read(in_shape);
  80. reader.read(padding_h);
  81. reader.read(padding_w);
  82. reader.read(filter_h);
  83. reader.read(filter_w);
  84. reader.read(stride_h);
  85. reader.read(stride_w);
  86. reader.read(dilation_h);
  87. reader.read(dilation_w);
  88. reader.read(fused_activation);
  89. reader.read_span(weights, (size_t)in_shape[3] * filter_h * filter_w);
  90. reader.read_span(bias, in_shape[3]);
  91. }
  92. };
  93. struct cpu_reduce_window2d_options : simple_node_body<cpu_reduce_window2d_options>
  94. {
  95. memory_range input;
  96. memory_range output;
  97. reduce_op_t reduce_op;
  98. runtime_shape_t in_shape;
  99. padding padding_h;
  100. padding padding_w;
  101. int32_t filter_h;
  102. int32_t filter_w;
  103. int32_t stride_h;
  104. int32_t stride_w;
  105. int32_t dilation_h;
  106. int32_t dilation_w;
  107. float init_value;
  108. value_range<float> fused_activation;
  109. };
  110. struct cpu_quantized_conv2d_options
  111. {
  112. memory_range input;
  113. memory_range output;
  114. runtime_shape_t in_shape;
  115. int32_t out_channels;
  116. padding padding_h;
  117. padding padding_w;
  118. int32_t filter_h;
  119. int32_t filter_w;
  120. int32_t stride_h;
  121. int32_t stride_w;
  122. int32_t dilation_h;
  123. int32_t dilation_w;
  124. int32_t input_offset;
  125. int32_t filter_offset;
  126. int32_t output_mul;
  127. int32_t output_shift;
  128. int32_t output_offset;
  129. xtl::span<const uint8_t> weights;
  130. xtl::span<const int32_t> bias;
  131. void deserialize(span_reader &reader)
  132. {
  133. reader.read(input);
  134. reader.read(output);
  135. reader.read(in_shape);
  136. reader.read(out_channels);
  137. reader.read(padding_h);
  138. reader.read(padding_w);
  139. reader.read(filter_h);
  140. reader.read(filter_w);
  141. reader.read(stride_h);
  142. reader.read(stride_w);
  143. reader.read(dilation_h);
  144. reader.read(dilation_w);
  145. reader.read(input_offset);
  146. reader.read(filter_offset);
  147. reader.read(output_mul);
  148. reader.read(output_shift);
  149. reader.read(output_offset);
  150. reader.read_span(weights, (size_t)out_channels * in_shape[3] * filter_h * filter_w);
  151. reader.read_span(bias, out_channels);
  152. }
  153. };
  154. struct cpu_quantized_depthwise_conv2d_options
  155. {
  156. memory_range input;
  157. memory_range output;
  158. runtime_shape_t in_shape;
  159. padding padding_h;
  160. padding padding_w;
  161. int32_t filter_h;
  162. int32_t filter_w;
  163. int32_t stride_h;
  164. int32_t stride_w;
  165. int32_t dilation_h;
  166. int32_t dilation_w;
  167. int32_t input_offset;
  168. int32_t filter_offset;
  169. int32_t output_mul;
  170. int32_t output_shift;
  171. int32_t output_offset;
  172. xtl::span<const uint8_t> weights;
  173. xtl::span<const int32_t> bias;
  174. void deserialize(span_reader &reader)
  175. {
  176. reader.read(input);
  177. reader.read(output);
  178. reader.read(in_shape);
  179. reader.read(padding_h);
  180. reader.read(padding_w);
  181. reader.read(filter_h);
  182. reader.read(filter_w);
  183. reader.read(stride_h);
  184. reader.read(stride_w);
  185. reader.read(dilation_h);
  186. reader.read(dilation_w);
  187. reader.read(input_offset);
  188. reader.read(filter_offset);
  189. reader.read(output_mul);
  190. reader.read(output_shift);
  191. reader.read(output_offset);
  192. reader.read_span(weights, (size_t)in_shape[3] * filter_h * filter_w);
  193. reader.read_span(bias, in_shape[3]);
  194. }
  195. };
  196. }
  197. }
  198. }