neutral_kernels.h 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. /* Copyright 2019-2020 Canaan Inc.
  2. *
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. *
  7. * http://www.apache.org/licenses/LICENSE-2.0
  8. *
  9. * Unless required by applicable law or agreed to in writing, software
  10. * distributed under the License is distributed on an "AS IS" BASIS,
  11. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. * See the License for the specific language governing permissions and
  13. * limitations under the License.
  14. */
  15. #pragma once
  16. #include "../kernel_utils.h"
  17. #include <cmath>
  18. #include <runtime/runtime_op_utility.h>
  19. #include <xtl/xspan.hpp>
  20. namespace nncase
  21. {
  22. namespace kernels
  23. {
  24. namespace neutral
  25. {
  26. template <class TQ>
  27. void riscv_dequantize(const TQ *CXX_RESTRICT input, float *CXX_RESTRICT output, size_t count, const quant_param_t &param)
  28. {
  29. float scale = 1.f / param.scale;
  30. float zero = -param.zero_point * scale;
  31. for (size_t i = 0; i < count / 2; i++)
  32. {
  33. // handwritten pipeline for in order CPU
  34. auto in1_q = input[i * 2];
  35. auto in2_q = input[i * 2 + 1];
  36. auto in1 = (float)in1_q;
  37. auto in2 = (float)in2_q;
  38. auto out1 = in1 * scale + zero;
  39. auto out2 = in2 * scale + zero;
  40. output[i * 2] = out1;
  41. output[i * 2 + 1] = out2;
  42. }
  43. if (count % 2)
  44. output[count - 1] = input[count - 1] * scale + zero;
  45. }
  46. template <class TQ>
  47. void riscv_quantize(const float *CXX_RESTRICT input, TQ *CXX_RESTRICT output, size_t count, const quant_param_t &param)
  48. {
  49. float scale = param.scale;
  50. float zero = param.zero_point;
  51. for (size_t i = 0; i < count / 2; i++)
  52. {
  53. auto in1 = input[i * 2];
  54. auto in2 = input[i * 2 + 1];
  55. in1 = in1 * scale + zero;
  56. in2 = in2 * scale + zero;
  57. int32_t out1, out2;
  58. asm volatile("fcvt.w.s %0, %1, rne"
  59. : "=r"(out1)
  60. : "f"(in1));
  61. asm volatile("fcvt.w.s %0, %1, rne"
  62. : "=r"(out2)
  63. : "f"(in2));
  64. output[i * 2] = std::clamp(out1, (int32_t)std::numeric_limits<TQ>::lowest(), (int32_t)std::numeric_limits<TQ>::max());
  65. output[i * 2 + 1] = std::clamp(out2, (int32_t)std::numeric_limits<TQ>::lowest(), (int32_t)std::numeric_limits<TQ>::max());
  66. }
  67. if (count % 2)
  68. {
  69. auto in = (int32_t)roundf(input[count - 1] * scale + zero);
  70. output[count - 1] = std::clamp(in, (int32_t)std::numeric_limits<TQ>::lowest(), (int32_t)std::numeric_limits<TQ>::max());
  71. }
  72. }
  73. }
  74. }
  75. }