// neutral_ops.cpp
/* Copyright 2019-2020 Canaan Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  15. #include <kernels/neutral/neutral_kernels.h>
  16. #include <runtime/kernel_registry.h>
  17. #include <runtime/neutral/neutral_ops_body.h>
  18. using namespace nncase;
  19. using namespace nncase::runtime;
  20. #define ELEM_SIZE_IMPL(type, KERNEL) \
  21. switch (runtime::get_bytes(type)) \
  22. { \
  23. case 1: \
  24. KERNEL(uint8_t); \
  25. break; \
  26. case 2: \
  27. KERNEL(uint16_t); \
  28. break; \
  29. case 4: \
  30. KERNEL(uint32_t); \
  31. break; \
  32. default: \
  33. return kcr_error; \
  34. }
  35. #define FP_OR_Q_IMPL(type, KERNEL) \
  36. switch (type) \
  37. { \
  38. case dt_float32: \
  39. KERNEL(float); \
  40. break; \
  41. case dt_uint8: \
  42. KERNEL(uint8_t); \
  43. break; \
  44. default: \
  45. return kcr_error; \
  46. }
  47. namespace nncase
  48. {
  49. namespace runtime
  50. {
  51. namespace neutral
  52. {
  53. kernel_call_result binary(binary_options &options, interpreter_t &interpreter, interpreter_step_t step)
  54. {
  55. auto input_a = interpreter.memory_at<float>(options.input_a);
  56. auto input_b = interpreter.memory_at<float>(options.input_b);
  57. auto output = interpreter.memory_at<float>(options.output);
  58. auto binary = [&](auto op) {
  59. kernels::neutral::binary(input_a.data(), input_b.data(), output.data(), options.in_a_shape, options.in_b_shape, options.out_shape, options.fused_activation, op);
  60. };
  61. switch (options.binary_op)
  62. {
  63. case binary_add:
  64. binary([](auto a, auto b) { return a + b; });
  65. return kcr_done;
  66. case binary_sub:
  67. binary([](auto a, auto b) { return a - b; });
  68. return kcr_done;
  69. case binary_mul:
  70. binary([](auto a, auto b) { return a * b; });
  71. return kcr_done;
  72. case binary_div:
  73. binary([](auto a, auto b) { return a / b; });
  74. return kcr_done;
  75. case binary_min:
  76. binary([](auto a, auto b) { return std::min(a, b); });
  77. return kcr_done;
  78. case binary_max:
  79. binary([](auto a, auto b) { return std::max(a, b); });
  80. return kcr_done;
  81. default:
  82. return kcr_error;
  83. }
  84. }
  85. kernel_call_result quantized_binary(quantized_binary_options &options, interpreter_t &interpreter, interpreter_step_t step)
  86. {
  87. auto input_a = interpreter.memory_at<uint8_t>(options.input_a);
  88. auto input_b = interpreter.memory_at<uint8_t>(options.input_b);
  89. auto output = interpreter.memory_at<uint8_t>(options.output);
  90. auto binary = [&](auto op) {
  91. kernels::neutral::quantized_binary(input_a.data(), input_b.data(), output.data(), options.in_a_shape, options.in_b_shape, options.out_shape,
  92. options.input_a_offset, options.input_a_mul, options.input_a_shift, options.input_b_offset, options.input_b_mul, options.input_b_shift,
  93. options.output_mul, options.output_shift, options.output_offset, op);
  94. };
  95. switch (options.binary_op)
  96. {
  97. case binary_add:
  98. binary([](auto a, auto b) { return a + b; });
  99. return kcr_done;
  100. case binary_sub:
  101. binary([](auto a, auto b) { return a - b; });
  102. return kcr_done;
  103. case binary_mul:
  104. binary([](auto a, auto b) { return a * b; });
  105. return kcr_done;
  106. case binary_div:
  107. binary([](auto a, auto b) { return (a + b / 2) / b; });
  108. return kcr_done;
  109. case binary_min:
  110. binary([](auto a, auto b) { return std::min(a, b); });
  111. return kcr_done;
  112. case binary_max:
  113. binary([](auto a, auto b) { return std::max(a, b); });
  114. return kcr_done;
  115. default:
  116. return kcr_error;
  117. }
  118. }
  119. kernel_call_result concat(concat_options &options, interpreter_t &interpreter, interpreter_step_t step)
  120. {
  121. auto output = interpreter.memory_at<uint8_t>(options.output);
  122. kernels::neutral::concat(options.inputs, output.data(), options.dims, options.inner_size, options.outer_size,
  123. [&](const memory_range &range) { return interpreter.memory_at<uint8_t>(range).data(); });
  124. return kcr_done;
  125. }
  126. kernel_call_result conv2d(conv2d_options &options, interpreter_t &interpreter, interpreter_step_t step)
  127. {
  128. auto input = interpreter.memory_at<float>(options.input);
  129. auto output = interpreter.memory_at<float>(options.output);
  130. kernels::neutral::conv2d(input.data(), output.data(), options.weights.data(), options.bias.data(), options.in_shape, options.groups, options.out_channels, options.filter_h,
  131. options.filter_w, options.stride_h, options.stride_w, options.dilation_h, options.dilation_w, options.padding_h, options.padding_w, options.fused_activation);
  132. return kcr_done;
  133. }
  134. kernel_call_result quantized_conv2d(quantized_conv2d_options &options, interpreter_t &interpreter, interpreter_step_t step)
  135. {
  136. auto input = interpreter.memory_at<uint8_t>(options.input);
  137. auto output = interpreter.memory_at<uint8_t>(options.output);
  138. kernels::neutral::quantized_conv2d(input.data(), output.data(), options.weights.data(), options.bias.data(), options.input_offset, options.filter_offset,
  139. options.output_mul, options.output_shift, options.output_offset, options.in_shape, options.groups, options.out_channels, options.filter_h,
  140. options.filter_w, options.stride_h, options.stride_w, options.dilation_h, options.dilation_w, options.padding_h, options.padding_w);
  141. return kcr_done;
  142. }
  143. kernel_call_result conv2d_transpose(conv2d_transpose_options &options, interpreter_t &interpreter, interpreter_step_t step)
  144. {
  145. auto input = interpreter.memory_at<float>(options.input);
  146. auto output = interpreter.memory_at<float>(options.output);
  147. kernels::neutral::conv2d_transpose(input.data(), output.data(), options.weights.data(), options.bias.data(), options.in_shape, options.groups, options.out_shape, options.filter_h,
  148. options.filter_w, options.stride_h, options.stride_w, options.dilation_h, options.dilation_w, options.padding_h, options.padding_w, options.fused_activation);
  149. return kcr_done;
  150. }
  151. kernel_call_result dequantize(dequantize_options &options, interpreter_t &interpreter, interpreter_step_t step)
  152. {
  153. auto input = interpreter.memory_at<uint8_t>(options.input);
  154. auto output = interpreter.memory_at<float>(options.output);
  155. kernels::neutral::dequantize(input.data(), output.data(), input.size(), options.quant_param);
  156. return kcr_done;
  157. }
  158. kernel_call_result matmul(matmul_options &options, interpreter_t &interpreter, interpreter_step_t step)
  159. {
  160. auto input_a = interpreter.memory_at<float>(options.input_a);
  161. auto input_b = interpreter.memory_at<float>(options.input_b);
  162. auto output = interpreter.memory_at<float>(options.output);
  163. kernels::neutral::matmul(input_a.data(), input_b.data(), output.data(), options.bias.data(), options.a_rows, options.a_cols, options.b_cols, options.fused_activation);
  164. return kcr_done;
  165. }
  166. kernel_call_result quantized_matmul(quantized_matmul_options &options, interpreter_t &interpreter, interpreter_step_t step)
  167. {
  168. auto input_a = interpreter.memory_at<uint8_t>(options.input_a);
  169. auto input_b = interpreter.memory_at<uint8_t>(options.input_b);
  170. auto output = interpreter.memory_at<uint8_t>(options.output);
  171. kernels::neutral::quantized_matmul(input_a.data(), input_b.data(), output.data(), options.bias.data(), options.a_rows, options.a_cols, options.b_cols,
  172. options.input_a_offset, options.input_b_offset, options.output_mul, options.output_shift, options.output_offset);
  173. return kcr_done;
  174. }
  175. kernel_call_result memory_copy(memory_copy_options &options, interpreter_t &interpreter, interpreter_step_t step)
  176. {
  177. auto input = interpreter.memory_at<float>(options.input);
  178. auto output = interpreter.memory_at<float>(options.output);
  179. std::copy(input.begin(), input.end(), output.begin());
  180. return kcr_done;
  181. }
  182. kernel_call_result pad(pad_options &options, interpreter_t &interpreter, interpreter_step_t step)
  183. {
  184. auto input = interpreter.memory_at<uint8_t>(options.input);
  185. auto output = interpreter.memory_at<uint8_t>(options.output);
  186. #define PAD_KERNEL(T) \
  187. kernels::neutral::pad(reinterpret_cast<const T *>(input.data()), reinterpret_cast<T *>(output.data()), options.in_shape, options.paddings, options.pad_value.as<T>());
  188. ELEM_SIZE_IMPL(options.input.datatype, PAD_KERNEL);
  189. return kcr_done;
  190. #undef PAD_KERNEL
  191. }
  192. kernel_call_result quantize(quantize_options &options, interpreter_t &interpreter, interpreter_step_t step)
  193. {
  194. auto input = interpreter.memory_at<float>(options.input);
  195. auto output = interpreter.memory_at<uint8_t>(options.output);
  196. kernels::neutral::quantize(input.data(), output.data(), input.size(), options.quant_param);
  197. return runtime::kcr_done;
  198. }
  199. kernel_call_result reduce(reduce_options &options, interpreter_t &interpreter, interpreter_step_t step)
  200. {
  201. auto input = interpreter.memory_at<float>(options.input);
  202. auto output = interpreter.memory_at<float>(options.output);
  203. auto reduce = [&](auto op) {
  204. kernels::neutral::reduce(input.data(), output.data(), options.init_value, options.in_shape, options.out_shape, op);
  205. };
  206. switch (options.reduce_op)
  207. {
  208. case reduce_mean:
  209. {
  210. reduce([](auto a, auto b) { return a + b; });
  211. auto mul = (float)output.size() / input.size();
  212. kernels::neutral::unary(output.data(), output.data(), output.size(), [mul](auto a) { return a * mul; });
  213. return kcr_done;
  214. }
  215. case reduce_min:
  216. reduce([](auto a, auto b) { return std::min(a, b); });
  217. return kcr_done;
  218. case reduce_max:
  219. reduce([](auto a, auto b) { return std::max(a, b); });
  220. return kcr_done;
  221. case reduce_sum:
  222. reduce([](auto a, auto b) { return a + b; });
  223. return kcr_done;
  224. default:
  225. return kcr_error;
  226. }
  227. }
  228. kernel_call_result reduce_window2d(reduce_window2d_options &options, interpreter_t &interpreter, interpreter_step_t step)
  229. {
  230. auto input = interpreter.memory_at<float>(options.input);
  231. auto output = interpreter.memory_at<float>(options.output);
  232. auto reduce = [&](auto binary_op, auto window_op) {
  233. kernels::neutral::reduce_window2d(input.data(), output.data(), options.init_value, options.in_shape, options.filter_h, options.filter_w, options.stride_h,
  234. options.stride_w, options.dilation_h, options.dilation_w, options.padding_h, options.padding_w, options.fused_activation, binary_op, window_op);
  235. };
  236. switch (options.reduce_op)
  237. {
  238. case reduce_mean:
  239. reduce([](auto a, auto b) { return a + b; }, [](auto v, auto k) { return v / k; });
  240. return kcr_done;
  241. case reduce_min:
  242. reduce([](auto a, auto b) { return std::min(a, b); }, [](auto v, auto k) { return v; });
  243. return kcr_done;
  244. case reduce_max:
  245. reduce([](auto a, auto b) { return std::max(a, b); }, [](auto v, auto k) { return v; });
  246. return kcr_done;
  247. case reduce_sum:
  248. reduce([](auto a, auto b) { return a + b; }, [](auto v, auto k) { return v; });
  249. return kcr_done;
  250. default:
  251. return kcr_error;
  252. }
  253. }
  254. kernel_call_result resize_image(resize_image_options &options, interpreter_t &interpreter, interpreter_step_t step)
  255. {
  256. auto input = interpreter.memory_at<uint8_t>(options.input);
  257. auto output = interpreter.memory_at<uint8_t>(options.output);
  258. if (options.mode == image_resize_bilinear)
  259. {
  260. #define RESIZE_BL_KERNEL(T) \
  261. kernels::neutral::resize_bilinear(reinterpret_cast<const T *>(input.data()), reinterpret_cast<T *>(output.data()), options.in_shape, options.out_h, options.out_w, options.align_corners);
  262. FP_OR_Q_IMPL(options.input.datatype, RESIZE_BL_KERNEL);
  263. return kcr_done;
  264. #undef RESIZE_BL_KERNEL
  265. }
  266. else
  267. {
  268. #define RESIZE_NN_KERNEL(T) \
  269. kernels::neutral::resize_nearest_neighbor(reinterpret_cast<const T *>(input.data()), reinterpret_cast<T *>(output.data()), options.in_shape, options.out_h, options.out_w);
  270. FP_OR_Q_IMPL(options.input.datatype, RESIZE_NN_KERNEL);
  271. return kcr_done;
  272. #undef RESIZE_NN_KERNEL
  273. }
  274. }
  275. kernel_call_result softmax(softmax_options &options, interpreter_t &interpreter, interpreter_step_t step)
  276. {
  277. auto input = interpreter.memory_at<float>(options.input);
  278. auto output = interpreter.memory_at<float>(options.output);
  279. kernels::neutral::softmax(input.data(), output.data(), options.beta, options.outer_size, options.inner_size);
  280. return kcr_done;
  281. }
  282. kernel_call_result transpose(transpose_options &options, interpreter_t &interpreter, interpreter_step_t step)
  283. {
  284. auto input = interpreter.memory_at<uint8_t>(options.input);
  285. auto output = interpreter.memory_at<uint8_t>(options.output);
  286. #define TRANSPOSE_KERNEL(T) \
  287. kernels::neutral::transpose(reinterpret_cast<const T *>(input.data()), reinterpret_cast<T *>(output.data()), options.in_shape, options.perm);
  288. ELEM_SIZE_IMPL(options.input.datatype, TRANSPOSE_KERNEL);
  289. return kcr_done;
  290. #undef TRANSPOSE_KERNEL
  291. }
  292. kernel_call_result strided_slice(strided_slice_options &options, interpreter_t &interpreter, interpreter_step_t step)
  293. {
  294. auto input = interpreter.memory_at<uint8_t>(options.input);
  295. auto output = interpreter.memory_at<uint8_t>(options.output);
  296. #define STRIDED_SLICE_KERNEL(T) \
  297. kernels::neutral::strided_slice(reinterpret_cast<const T *>(input.data()), reinterpret_cast<T *>(output.data()), options.in_shape, options.begin, options.end, options.strides);
  298. ELEM_SIZE_IMPL(options.input.datatype, STRIDED_SLICE_KERNEL);
  299. return kcr_done;
  300. #undef STRIDED_SLICE_KERNEL
  301. }
  302. kernel_call_result unary(unary_options &options, interpreter_t &interpreter, interpreter_step_t step)
  303. {
  304. auto input = interpreter.memory_at<float>(options.input);
  305. auto output = interpreter.memory_at<float>(options.output);
  306. auto unary = [&](auto unary_op) {
  307. kernels::neutral::unary(input.data(), output.data(), input.size(), unary_op);
  308. };
  309. switch (options.unary_op)
  310. {
  311. case unary_abs:
  312. unary([](auto a) { return fabs(a); });
  313. return kcr_done;
  314. case unary_ceil:
  315. unary([](auto a) { return ceilf(a); });
  316. return kcr_done;
  317. case unary_cos:
  318. unary([](auto a) { return cosf(a); });
  319. return kcr_done;
  320. case unary_exp:
  321. unary([](auto a) { return expf(a); });
  322. return kcr_done;
  323. case unary_floor:
  324. unary([](auto a) { return floorf(a); });
  325. return kcr_done;
  326. case unary_log:
  327. unary([](auto a) { return logf(a); });
  328. return kcr_done;
  329. case unary_neg:
  330. unary([](auto a) { return -a; });
  331. return kcr_done;
  332. case unary_rsqrt:
  333. unary([](auto a) { return 1.f / sqrtf(a); });
  334. return kcr_done;
  335. case unary_sin:
  336. unary([](auto a) { return sinf(a); });
  337. return kcr_done;
  338. case unary_square:
  339. unary([](auto a) { return a * a; });
  340. return kcr_done;
  341. default:
  342. return kcr_error;
  343. }
  344. }
  345. kernel_call_result nnil_unary_method(nnil_unary_method_options &options, interpreter_t &interpreter, interpreter_step_t step)
  346. {
  347. auto input = interpreter.memory_at<float>(options.input);
  348. auto output = interpreter.memory_at<float>(options.output);
  349. kernels::neutral::nnil_unary_method(input.data(), output.data(), input.size(), options.body);
  350. return kcr_done;
  351. }
  352. kernel_call_result table_lookup1d(table_lookup1d_options &options, interpreter_t &interpreter, interpreter_step_t step)
  353. {
  354. if (options.input.datatype != dt_uint8)
  355. return kcr_error;
  356. auto input = interpreter.memory_at<uint8_t>(options.input);
  357. auto table = interpreter.memory_at<uint8_t>(options.table);
  358. auto output = interpreter.memory_at<uint8_t>(options.output);
  359. kernels::neutral::table_lookup1d(input.data(), output.data(), input.size(), table.data());
  360. return kcr_done;
  361. }
  362. }
  363. }
  364. }