portable_tensor_utils.cc

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cassert>  // For assert() in PortableApplyTanh.
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include <utility>

#include "fixedpoint/fixedpoint.h"
#include "tflite/c/builtin_op_data.h"
#include "tflite/kernels/internal/common.h"
#include "tflite/kernels/internal/compatibility.h"
#include "tflite/kernels/internal/cppmath.h"
#include "tflite/kernels/internal/reference/portable_tensor_utils_impl.h"

#if defined(_MSC_VER)
#define __restrict__ __restrict
#endif

namespace tflite {
namespace tensor_utils {

namespace {
const int32_t kInt16Max = std::numeric_limits<int16_t>::max();
const int32_t kInt16Min = std::numeric_limits<int16_t>::min();
}  // namespace

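// Symmetric per-tensor quantization: scaling_factor = max(|min|, |max|) / 127
// and q[i] = round(x[i] / scaling_factor), clamped to [-127, 127], so that
// x[i] ~ q[i] * scaling_factor.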
void PortableSymmetricQuantizeFloats(const float* values, const int size,
                                     int8_t* quantized_values, float* min_value,
                                     float* max_value, float* scaling_factor) {
  auto minmax = std::minmax_element(values, values + size);
  *min_value = *minmax.first;
  *max_value = *minmax.second;

  PortableSymmetricQuantizeFloats(values, size, quantized_values, *min_value,
                                  *max_value, scaling_factor);
}

void PortableSymmetricQuantizeFloats(const float* values, const int size,
                                     int8_t* quantized_values, float min_value,
                                     float max_value, float* scaling_factor) {
  const int32_t kScale = 127;
  const float range = std::max(std::abs(min_value), std::abs(max_value));
  if (range == 0) {
    memset(quantized_values, 0, size * sizeof(int8_t));
    *scaling_factor = 1;
    return;
  }
  *scaling_factor = range / kScale;
  const float scaling_factor_inv = kScale / range;
  for (int i = 0; i < size; ++i) {
    const int32_t quantized_value =
        static_cast<int32_t>(TfLiteRound(values[i] * scaling_factor_inv));
    // Clamp: just in case some odd numeric offset.
    quantized_values[i] = static_cast<int8_t>(
        std::min(kScale, std::max(-kScale, quantized_value)));
  }
}

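// Asymmetric per-tensor quantization: the real range [rmin, rmax] (always
// extended to include 0) is mapped onto [-128, 127]. The zero point is taken
// from whichever of the min- or max-derived estimates has the smaller error,
// then rounded and clamped into [-128, 127], so that
// x[i] ~ (q[i] - offset) * scaling_factor.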
void PortableAsymmetricQuantizeFloats(const float* values, const int size,
                                      int8_t* quantized_values,
                                      float* scaling_factor, int32_t* offset) {
  const int32_t kMinScale = -128;
  const int32_t kMaxScale = 127;
  const double qmin_double = kMinScale;
  const double qmax_double = kMaxScale;
  const auto minmax = std::minmax_element(values, values + size);
  const double rmin = ::fmin(0, *minmax.first);
  const double rmax = ::fmax(0, *minmax.second);
  if (rmin == rmax) {
    memset(quantized_values, 0, size * sizeof(int8_t));
    *scaling_factor = 1;
    *offset = 0;
    return;
  } else {
    double scale = (rmax - rmin) / (qmax_double - qmin_double);
    const double zero_point_from_min = qmin_double - rmin / scale;
    const double zero_point_from_max = qmax_double - rmax / scale;
    const double zero_point_from_min_error =
        std::abs(qmin_double) + std::abs(rmin / scale);
    const double zero_point_from_max_error =
        std::abs(qmax_double) + std::abs(rmax / scale);
    const double zero_point_double =
        zero_point_from_min_error < zero_point_from_max_error
            ? zero_point_from_min
            : zero_point_from_max;
    int8 nudged_zero_point = 0;
    if (zero_point_double <= qmin_double) {
      nudged_zero_point = kMinScale;
    } else if (zero_point_double >= qmax_double) {
      nudged_zero_point = kMaxScale;
    } else {
      nudged_zero_point = static_cast<int8>(round(zero_point_double));
    }
    *scaling_factor = scale;
    *offset = nudged_zero_point;
  }
  const float scaling_factor_inv = 1.0 / *scaling_factor;
  for (int i = 0; i < size; ++i) {
    const int32_t quantized_value = static_cast<int32_t>(
        TfLiteRound(*offset + values[i] * scaling_factor_inv));
    quantized_values[i] =
        std::min(kMaxScale, std::max(kMinScale, quantized_value));
  }
}

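// Quantized matrix * batch-vector product accumulated into a float result:
// each int32 row dot product is scaled back to float with the per-batch
// scaling factor produced by the quantization routines above.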
void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
                                                 int m_rows, int m_cols,
                                                 const float* vector,
                                                 int n_batch, float* result) {
  float* result_in_batch = result;
  for (int b = 0; b < n_batch; b++) {
    const float* matrix_ptr = matrix;
    for (int r = 0; r < m_rows; r++) {
      float dot_prod = 0.0f;
      const float* vector_in_batch = vector + b * m_cols;
      for (int c = 0; c < m_cols; c++) {
        dot_prod += *matrix_ptr++ * *vector_in_batch++;
      }
      *result_in_batch += dot_prod;
      ++result_in_batch;
    }
  }
}

void PortableMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors, const float* scaling_factors,
    int n_batch, float* __restrict__ result) {
  for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
    const float batch_scaling_factor = scaling_factors[batch];
    // Get the address of the first row.
    const int8_t* row_ptr = matrix;
    for (int row = 0; row < m_rows; ++row) {
      // Initialize the dot product sum for the row to 0.
      int32_t dotprod = 0;
#if defined(__GNUC__)
      // Prefetch the row to cache.
      __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
                         3 /* temporal locality */);
#endif
      for (int col = 0; col < m_cols; ++col, ++row_ptr) {
        dotprod += (*row_ptr) * (vectors[col]);
      }  // for col
      *result += dotprod * batch_scaling_factor;
      ++result;
    }  // for row
  }  // for batch
}

void PortableMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors, const float* scaling_factors,
    int n_batch, float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
    bool* compute_row_sums, CpuBackendContext* context) {
  if (input_offset == nullptr) {
    PortableMatrixBatchVectorMultiplyAccumulate(
        matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
    return;
  }
  if (!compute_row_sums || *compute_row_sums) {
    memset(row_sums, 0, sizeof(int32_t) * m_rows);
    PortableReductionSumVector(matrix, row_sums, m_rows, m_cols);
    if (compute_row_sums) {
      *compute_row_sums = false;
    }
  }
  for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
    const float batch_scaling_factor = scaling_factors[batch];
    const float batch_offset = input_offset[batch];
    const int8_t* row_ptr = matrix;
    for (int row = 0; row < m_rows; ++row) {
      int32_t dotprod = 0;
      float scale = batch_scaling_factor;
      if (per_channel_scale) {
        scale *= per_channel_scale[row];
      }
#if defined(__GNUC__)
      // Prefetch the row to cache.
      __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
                         3 /* temporal locality */);
#endif
      for (int col = 0; col < m_cols; ++col, ++row_ptr) {
        dotprod += (*row_ptr) * vectors[col];
      }  // for col
      dotprod -= row_sums[row] * batch_offset;
      *result += dotprod * scale;
      ++result;
    }  // for row
  }  // for batch
}

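// Block-sparse (1x4) format: `segments` holds, per row, the [start, end)
// range into `indices`, and each index selects a 4-wide column block whose
// values are stored contiguously in `matrix`.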
void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
    const int32_t* __restrict__ indices, int m_rows, int m_cols,
    const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
  const int kBlockSize = 4;
  TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
  for (int batch = 0; batch < n_batch; batch++) {
    const float* matrix_ptr = matrix;
    for (int row = 0; row < m_rows; row++) {
      float dot_prod = 0.0f;
      const float* vector_in_batch = vector + batch * m_cols;
      for (int i = segments[row]; i < segments[row + 1]; i++) {
        const int block_start_index = indices[i] * kBlockSize;
        const float* vector_block_in_batch_ptr =
            vector_in_batch + block_start_index;
        for (int c = 0; c < kBlockSize; c++) {
          dot_prod += *matrix_ptr++ * *vector_block_in_batch_ptr++;
        }
      }
      result[batch * m_rows + row] += dot_prod;
    }
  }
}

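// Ledger format: for every row the ledger stores the number of non-zero
// 16-wide blocks followed by the indices of those blocks; only the non-zero
// blocks are stored in `matrix` and multiplied against the dense vector.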
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
    const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
    float* __restrict__ result) {
  const int kBlockSize = 16;
  TFLITE_DCHECK_EQ(  // NOLINT
      m_cols % kBlockSize, 0);
  for (int batch = 0; batch < n_batch; batch++) {
    const float* matrix_ptr = matrix;
    const uint8_t* ledger_ptr = ledger;
    for (int row = 0; row < m_rows; row++) {
      float dot_prod = 0.0f;
      int num_nonzero_blocks = *ledger_ptr++;
      if (num_nonzero_blocks > 0) {
        const float* vector_in_batch = vector + batch * m_cols;
        for (int i = 0; i < num_nonzero_blocks; i++) {
          const int block_start_index = *ledger_ptr++ * kBlockSize;
          const float* vector_block_in_batch_ptr =
              vector_in_batch + block_start_index;
          for (int c = 0; c < kBlockSize; c++) {
            dot_prod += *matrix_ptr++ * *vector_block_in_batch_ptr++;
          }
        }
      }
      result[batch * m_rows + row] += dot_prod;
    }
  }
}

void PortableSparseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
    const int m_cols, const int8_t* __restrict__ vectors,
    const float* scaling_factors, int n_batch, float* __restrict__ result) {
  static const int kBlockSize = 16;
  TFLITE_DCHECK_EQ(  // NOLINT
      m_cols % kBlockSize, 0);
  for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
    const float batch_scaling_factor = scaling_factors[batch];
    const uint8_t* ledger_ptr = ledger;
    // Get the address of the first row.
    const int8_t* row_ptr = matrix;
    for (int row = 0; row < m_rows; ++row) {
      // Initialize the dot product sum for the row to 0.
      int32_t dotprod = 0;
#if defined(__GNUC__)
      // Prefetch the row to cache.
      __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
                         3 /* temporal locality */);
#endif
      int num_nonzero_blocks = *ledger_ptr++;
      for (int i = 0; i < num_nonzero_blocks; i++) {
        const int block_start_index = *ledger_ptr++ * kBlockSize;
        const int8_t* vector_block_ptr = vectors + block_start_index;
        for (int c = 0; c < kBlockSize; c++) {
          dotprod += (*row_ptr++) * (*vector_block_ptr++);
        }  // for block
      }  // for num_nonzero_blocks
      result[batch * m_rows + row] += dotprod * batch_scaling_factor;
    }  // for row
  }  // for batch
}

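// Shared implementation for the fully quantized accumulate: int32 row
// accumulators are requantized with (multiplier, shift), offset by output_zp,
// added to the existing output value, and saturated to the output type's
// range (int8 or int16).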
template <typename T>
void PortableMatrixBatchVectorMultiplyAccumulateImpl(
    const int8_t* input, const int32_t* bias,
    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
    T* output) {
  const int16_t output_max = std::numeric_limits<T>::max();
  const int16_t output_min = std::numeric_limits<T>::min();
  for (int batch = 0; batch < n_batch; ++batch) {
    for (int row = 0; row < n_output; ++row) {
      int32_t acc = bias[row];
      for (int col = 0; col < n_input; ++col) {
        int8 input_val = input[batch * n_input + col];
        int8 weights_val = input_to_gate_weights[row * n_input + col];
        acc += input_val * weights_val;
      }
      acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
      acc += output_zp;
      acc += output[batch * n_output + row];
      if (acc > output_max) {
        acc = output_max;
      }
      if (acc < output_min) {
        acc = output_min;
      }
      output[batch * n_output + row] = static_cast<T>(acc);
    }
  }
}

void PortableMatrixBatchVectorMultiplyAccumulate(
    const int8_t* input, const int32_t* bias,
    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
    int32_t* scratch, int16_t* output, CpuBackendContext* context) {
  PortableMatrixBatchVectorMultiplyAccumulateImpl(
      input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
      n_output, output_zp, output);
}

void PortableMatrixBatchVectorMultiplyAccumulate(
    const int8_t* input, const int32_t* bias,
    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
    int32_t* scratch, int8_t* output, CpuBackendContext* context) {
  PortableMatrixBatchVectorMultiplyAccumulateImpl(
      input, bias, input_to_gate_weights, multiplier, shift, n_batch, n_input,
      n_output, output_zp, output);
}

void PortableMatrixBatchVectorMultiply(const int8_t* input,
                                       int32_t input_zeropoint,
                                       const int8_t* input_to_gate_weights,
                                       int32_t input_to_gate_effective_scale_a,
                                       int32_t input_to_gate_effective_scale_b,
                                       int32_t n_batch, int32_t n_input,
                                       int32_t n_cell, int8_t* gate_output,
                                       int8_t gate_output_zp) {
  const int32_t int8_max = std::numeric_limits<int8>::max();
  const int32_t int8_min = std::numeric_limits<int8>::min();
  for (int batch = 0; batch < n_batch; ++batch) {
    for (int row = 0; row < n_cell; ++row) {
      int32_t acc = 0;
      for (int col = 0; col < n_input; ++col) {
        int32_t input_val = input[batch * n_input + col];
        int8_t weights_val = input_to_gate_weights[row * n_input + col];
        acc += (input_val - input_zeropoint) * weights_val;
      }
      acc = MultiplyByQuantizedMultiplier(acc, input_to_gate_effective_scale_a,
                                          input_to_gate_effective_scale_b);
      acc += gate_output_zp;
      if (acc > int8_max) {
        acc = int8_max;
      }
      if (acc < int8_min) {
        acc = int8_min;
      }
      gate_output[batch * n_cell + row] = static_cast<int8_t>(acc);
    }
  }
}

void PortableMatrixBatchVectorMultiply(
    const int16_t* hidden, const int8_t* hidden_to_output_weights,
    int32_t proj_effective_scale_a, int32_t proj_effective_scale_b,
    const int32_t* gate_bias, int32_t n_batch, int32_t n_hidden,
    int32_t n_output, int32_t output_zp, int8_t* proj_output) {
  const int16_t int8_max = std::numeric_limits<int8>::max();
  const int16_t int8_min = std::numeric_limits<int8>::min();
  for (int batch = 0; batch < n_batch; ++batch) {
    for (int row = 0; row < n_output; ++row) {
      int64_t acc = gate_bias[row];
      for (int col = 0; col < n_hidden; ++col) {
        int16_t input_val = hidden[batch * n_hidden + col];
        int8_t weights_val = hidden_to_output_weights[row * n_hidden + col];
        int64_t curr = acc;
        acc += input_val * weights_val;
        if (input_val * weights_val > 0 && acc < curr) {
          acc = std::numeric_limits<int32>::max();
        }
        if (input_val * weights_val < 0 && acc > curr) {
          acc = std::numeric_limits<int32>::min();
        }
      }
      acc = MultiplyByQuantizedMultiplier(acc, proj_effective_scale_a,
                                          proj_effective_scale_b);
      acc += output_zp;
      if (acc > int8_max) {
        acc = int8_max;
      }
      if (acc < int8_min) {
        acc = int8_min;
      }
      proj_output[batch * n_output + row] = acc;
    }
  }
}

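// Integer layer normalization. Statistics are computed on values scaled by
// 1024 (Q10), so the variance carries a 2^20 factor; the inverse stddev is
// obtained from GetInvSqrtQuantizedMultiplierExp, and the weighted, biased
// result is rescaled to int16 with (layer_norm_scale_a, layer_norm_scale_b + 12).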
void PortableApplyLayerNorm(const int16_t* input,
                            const int16_t* layer_norm_weights,
                            const int32_t* bias, int32_t layer_norm_scale_a,
                            int32_t layer_norm_scale_b, int32_t variance_limit,
                            int n_batch, int n_input, int16_t* output) {
  // The square of std::pow(2, 10), which is the extra factor that makes sure
  // the normalized values have enough resolution.
  static const int kTwoToPower20 = 1 << 20;
  for (int i = 0; i < n_batch; ++i) {
    int64_t sum = 0;
    int64_t sum_sq = 0;
    for (int j = 0; j < n_input; ++j) {
      const int32_t index = i * n_input + j;
      int32_t val = static_cast<int32_t>(input[index]);
      sum += val;
      sum_sq += val * val;
    }
    int32_t mean =
        static_cast<int32_t>(static_cast<int64_t>(sum) * 1024 / n_input);
    // TODO(jianlijianli): Avoids overflow but only works for POT n_input.
    int32 temp = kTwoToPower20 / n_input;
    int64_t variance =
        sum_sq * temp - static_cast<int64_t>(mean) * static_cast<int64_t>(mean);
    int32_t variance2 = static_cast<int32>(variance / kTwoToPower20);
    if (variance2 < 1) {
      variance2 = variance_limit;
    }
    int32_t stddev_inverse_a;
    int stddev_inverse_b;
    GetInvSqrtQuantizedMultiplierExp(variance2, /*reverse_shift*/ -1,
                                     &stddev_inverse_a, &stddev_inverse_b);
    for (int j = 0; j < n_input; ++j) {
      const int32 index = i * n_input + j;
      int32 val = static_cast<int32_t>(input[index]);
      int32 shifted = 1024 * val - mean;
      int32 rescaled = MultiplyByQuantizedMultiplier(shifted, stddev_inverse_a,
                                                     stddev_inverse_b);
      // TODO(jianlijianli): Saturate this.
      int64_t val3 = rescaled * layer_norm_weights[j] + bias[j];
      int32 val4 =
          static_cast<int32>((val3 > 0 ? val3 + 512 : val3 - 512) / 1024);
      int32 val5 = MultiplyByQuantizedMultiplier(val4, layer_norm_scale_a,
                                                 layer_norm_scale_b + 12);
      val5 = std::min(std::max(kInt16Min, val5), kInt16Max);
      output[index] = static_cast<int16_t>(val5);
    }
  }
}

void PortableApplyLayerNormFloat(const int16_t* input,
                                 const int16_t* layer_norm_weights,
                                 int32_t layer_norm_scale_a,
                                 int32_t layer_norm_scale_b,
                                 const int32_t* bias, int n_batch, int n_input,
                                 int16_t* output) {
  const int32_t int16_max = std::numeric_limits<int16>::max();
  const int32_t int16_min = std::numeric_limits<int16>::min();
  // This is to suppress a lint warning.
  const double two = 2.0;
  const float layer_norm_scale =
      layer_norm_scale_a *
      std::pow(two, static_cast<double>(layer_norm_scale_b - 31));
  const float bias_scale = std::pow(two, -10) * layer_norm_scale;
  for (int batch = 0; batch < n_batch; ++batch) {
    float sum = 0.0f;
    float sum_sq = 0.0f;
    for (int i = 0; i < n_input; ++i) {
      const int index = batch * n_input + i;
      const float value = static_cast<float>(input[index]);
      sum += value;
      sum_sq += value * value;
    }
    const float mean = sum / n_input;
    float stddev_inv = 0.0f;
    const float variance = sum_sq / n_input - mean * mean;
    if (variance == 0) {
      stddev_inv = 1.0f / sqrt(1e-8);
    } else {
      stddev_inv = 1.0f / sqrt(variance);
    }
    for (int i = 0; i < n_input; ++i) {
      const int index = batch * n_input + i;
      const float normalized_value =
          (static_cast<float>(input[index]) - mean) * stddev_inv;
      const float weighted_normalized_value =
          normalized_value * layer_norm_weights[i] * layer_norm_scale +
          bias[i] * bias_scale;
      const int32_t quant_output = static_cast<int32>(
          ::round(weighted_normalized_value * std::pow(2, 12)));
      output[index] = std::min(int16_max, std::max(int16_min, quant_output));
    }
  }
}

void PortableMatrixScalarMultiplyAccumulate(const int8_t* matrix,
                                            int32_t scalar, int32_t n_row,
                                            int32_t n_col, int32_t* output) {
  for (int i = 0; i < n_row; ++i) {
    int32_t row_sum = 0;
    for (int j = 0; j < n_col; ++j) {
      row_sum += *matrix++;
    }
    output[i] += row_sum * scalar;
  }
}

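// Fixed-point sigmoid: the int16 input is interpreted as Q3.12 and the output
// as Q0.15, using gemmlowp's fixed-point logistic.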
void PortableApplySigmoid(const int16_t* input, int32_t n_batch,
                          int32_t n_input, int16_t* output) {
  for (int batch = 0; batch < n_batch; ++batch) {
    for (int c = 0; c < n_input; c++) {
      using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
      using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
      const int index = batch * n_input + c;
      F3 sigmoid_input = F3::FromRaw(input[index]);
      F0 sigmoid_output = gemmlowp::logistic(sigmoid_input);
      output[index] = sigmoid_output.raw();
    }
  }
}

void PortableApplySigmoidFloat(const int16_t* input, int32_t n_batch,
                               int32_t n_input, int16_t* output) {
  const int32_t int16_max = std::numeric_limits<int16>::max();
  const int32_t int16_min = std::numeric_limits<int16>::min();
  for (int batch = 0; batch < n_batch; ++batch) {
    for (int i = 0; i < n_input; ++i) {
      const int index = batch * n_input + i;
      const float float_input = input[index] * std::pow(2, -12);
      const float float_output = 1.0f / (1.0f + std::exp(-float_input));
      const int32_t quant_output =
          static_cast<int32>(float_output * std::pow(2, 15));
      const int32_t quant_output_clamped =
          std::min(int16_max, std::max(int16_min, quant_output));
      output[index] = static_cast<int16>(quant_output_clamped);
    }
  }
}

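// Fixed-point tanh: the number of integer bits of the Q-format int16 input is
// a template parameter, dispatched at runtime for 0..6 integer bits; the
// output is Q0.15.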
template <int IntegerBits>
void PortableApplyTanhImpl(const int16_t* input, int32_t n_batch,
                           int32_t n_input, int16_t* output) {
  using FX = gemmlowp::FixedPoint<std::int16_t, IntegerBits>;
  using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
  for (int batch = 0; batch < n_batch; ++batch) {
    for (int i = 0; i < n_input; ++i) {
      const int index = batch * n_input + i;
      FX tanh_input = FX::FromRaw(input[index]);
      F0 tanh_output = gemmlowp::tanh(tanh_input);
      output[index] = tanh_output.raw();
    }
  }
}

void PortableApplyTanh(int32_t integer_bits, const int16_t* input,
                       int32_t n_batch, int32_t n_input, int16_t* output) {
  assert(integer_bits <= 6);
#define DISPATCH_TANH(i)                                        \
  case i:                                                       \
    PortableApplyTanhImpl<i>(input, n_batch, n_input, output);  \
    break;
  switch (integer_bits) {
    DISPATCH_TANH(0);
    DISPATCH_TANH(1);
    DISPATCH_TANH(2);
    DISPATCH_TANH(3);
    DISPATCH_TANH(4);
    DISPATCH_TANH(5);
    DISPATCH_TANH(6);
    default:
      return;
  }
#undef DISPATCH_TANH
}

void PortableApplyTanhFloat(const int16_t* input, int32_t n_batch,
                            int32_t n_input, int32_t integer_bits,
                            int16_t* output) {
  const int32_t int16_max = std::numeric_limits<int16>::max();
  const int32_t int16_min = std::numeric_limits<int16>::min();
  const double two = 2.0;
  for (int batch = 0; batch < n_batch; ++batch) {
    for (int i = 0; i < n_input; ++i) {
      const int index = batch * n_input + i;
      const float float_input =
          input[index] * std::pow(two, static_cast<double>(integer_bits));
      const float float_output = std::tanh(float_input);
      const int32_t quant_output =
          static_cast<int32>(float_output * std::pow(2, 15));
      const int32_t quant_output_clamped =
          std::min(int16_max, std::max(int16_min, quant_output));
      output[index] = static_cast<int16>(quant_output_clamped);
    }
  }
}

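// Element-wise multiply of two Q-format int16 tensors; the int32 product is
// scaled back to int16 with a rounding right shift of `shift` bits. The int8
// overload below requantizes with (multiplier, shift), subtracts the output
// zero point, and clamps to [-128, 127].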
void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
                      int n_batch, int n_input, int shift, int16_t* output) {
  for (int batch = 0; batch < n_batch; ++batch) {
    for (int i = 0; i < n_input; ++i) {
      const int index = batch * n_input + i;
      const int16_t a = input_1[index];
      const int16_t b = input_2[index];
      const int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
      output[index] =
          static_cast<int16_t>(gemmlowp::RoundingDivideByPOT(value, shift));
    }
  }
}

void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
                      int32_t multiplier, int32_t shift, int32_t n_batch,
                      int32_t n_input, int32_t output_zp, int8_t* output) {
  for (int batch = 0; batch < n_batch; ++batch) {
    for (int i = 0; i < n_input; ++i) {
      const int index = batch * n_input + i;
      const int16_t a = input_1[index];
      const int16_t b = input_2[index];
      int32_t value = static_cast<int32_t>(a) * static_cast<int32_t>(b);
      value = MultiplyByQuantizedMultiplier(value, multiplier, shift);
      value -= output_zp;
      value = std::min(std::max(static_cast<int32_t>(-128), value),
                       static_cast<int32_t>(127));
      output[index] = static_cast<int8>(value);
    }
  }
}

void PortableCwiseAdd(const int16_t* input_1, const int16_t* input_2,
                      int n_batch, int n_input, int16_t* output) {
  for (int batch = 0; batch < n_batch; ++batch) {
    for (int i = 0; i < n_input; ++i) {
      const int index = batch * n_input + i;
      int32_t sum = input_1[index] + input_2[index];
      const int32 sum_clamped = std::min(kInt16Max, std::max(kInt16Min, sum));
      output[index] = static_cast<int16_t>(sum_clamped);
    }
  }
}

void PortableCwiseClipping(int16_t* input, const int16_t clipping_value,
                           int32_t n_batch, int32_t n_input) {
  for (int batch = 0; batch < n_batch; ++batch) {
    for (int i = 0; i < n_input; ++i) {
      const int index = batch * n_input + i;
      if (input[index] > clipping_value) {
        input[index] = clipping_value;
      }
      if (input[index] < -clipping_value) {
        input[index] = -clipping_value;
      }
    }
  }
}

void PortableCwiseClipping(int8_t* input, const int8_t clipping_value,
                           int32_t n_batch, int32_t n_input) {
  for (int batch = 0; batch < n_batch; ++batch) {
    for (int i = 0; i < n_input; ++i) {
      const int index = batch * n_input + i;
      if (input[index] > clipping_value) {
        input[index] = clipping_value;
      }
      if (input[index] < -clipping_value) {
        input[index] = -clipping_value;
      }
    }
  }
}

float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
                                     int v_size) {
  float result = 0.0;
  for (int v = 0; v < v_size; v++) {
    result += *vector1++ * *vector2++;
  }
  return result;
}

namespace {
inline int32_t VectorVectorDotProduct(const int16_t* vector1,
                                      const int16_t* vector2, int v_size) {
  int32_t result = 0;
  for (int v = 0; v < v_size; v++) {
    result += *vector1++ * *vector2++;
  }
  return result;
}
}  // namespace

void PortableBatchVectorBatchVectorDotProduct(const int16_t* vector1,
                                              const int16_t* vector2,
                                              int v_size, int n_batch,
                                              int32_t* result) {
  for (int b = 0; b < n_batch; b++) {
    result[b] = VectorVectorDotProduct(vector1, vector2, v_size);
    vector1 += v_size;
    vector2 += v_size;
  }
}

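// Per-element product of `vector` with each batch vector, requantized with
// (multiplier, shift), accumulated into `result`, and saturated to the int16
// range.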
void PortableVectorBatchVectorCwiseProductAccumulate(
    const int16_t* vector, int v_size, const int16_t* batch_vector, int n_batch,
    int32_t multiplier, int shift, int16_t* result) {
  for (int b = 0; b < n_batch; b++) {
    for (int v = 0; v < v_size; v++) {
      int32_t prod = vector[v] * *batch_vector++;
      prod = MultiplyByQuantizedMultiplier(prod, multiplier, shift);
      int32_t output = prod + *result;
      output = std::max(std::min(static_cast<int32_t>(32767), output),
                        static_cast<int32_t>(-32768));
      *result++ = output;
    }
  }
}

void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
                                  float* batch_vector) {
  for (int b = 0; b < n_batch; b++) {
    for (int i = 0; i < v_size; ++i) {
      batch_vector[i] += vector[i];
    }
    batch_vector += v_size;
  }
}

void PortableSub1Vector(const float* vector, int v_size, float* result) {
  for (int v = 0; v < v_size; v++) {
    *result++ = 1.0f - *vector++;
  }
}

void PortableSub1Vector(const int16_t* vector, int v_size, int16_t* result) {
  static const int16_t kOne = 32767;
  for (int v = 0; v < v_size; v++) {
    *result++ = kOne - *vector++;
  }
}

void PortableVectorScalarMultiply(const int8_t* vector, const int v_size,
                                  const float scale, float* result) {
  for (int v = 0; v < v_size; ++v) {
    *result++ = scale * *vector++;
  }
}

void PortableClipVector(const float* vector, int v_size, float abs_limit,
                        float* result) {
  for (int v = 0; v < v_size; v++) {
    result[v] = std::max(std::min(abs_limit, vector[v]), -abs_limit);
  }
}

void PortableReductionSumVector(const float* input_vector, float* output_vector,
                                int output_size, int reduction_size) {
  const float* input_vector_ptr = input_vector;
  for (int o = 0; o < output_size; o++) {
    for (int r = 0; r < reduction_size; r++) {
      output_vector[o] += *input_vector_ptr++;
    }
  }
}

void PortableReductionSumVector(const int32_t* input_vector,
                                int32_t* output_vector, int output_size,
                                int reduction_size) {
  const int32_t* input_vector_ptr = input_vector;
  for (int o = 0; o < output_size; o++) {
    for (int r = 0; r < reduction_size; r++) {
      output_vector[o] += *input_vector_ptr++;
    }
  }
}

void PortableReductionSumVector(const int8_t* input_vector,
                                int32_t* output_vector, int output_size,
                                int reduction_size) {
  const int8_t* input_vector_ptr = input_vector;
  for (int o = 0; o < output_size; o++) {
    for (int r = 0; r < reduction_size; r++) {
      output_vector[o] += *input_vector_ptr++;
    }
  }
}

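// Float mean/stddev normalization per batch; 1e-8 is added to the variance to
// avoid division by zero for constant inputs.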
void PortableMeanStddevNormalization(const float* input_vector,
                                     float* output_vector, int v_size,
                                     int n_batch) {
  for (int batch = 0; batch < n_batch; ++batch) {
    float sum = 0.0f;
    for (int i = 0; i < v_size; ++i) {
      sum += input_vector[i];
    }
    const float mean = sum / v_size;
    float sum_diff_sq = 0.0f;
    for (int i = 0; i < v_size; ++i) {
      const float diff = input_vector[i] - mean;
      sum_diff_sq += diff * diff;
    }
    const float variance = sum_diff_sq / v_size;
    constexpr float kNormalizationConstant = 1e-8f;
    const float stddev_inv =
        1.0f / std::sqrt(variance + kNormalizationConstant);
    for (int i = 0; i < v_size; ++i) {
      output_vector[i] = (input_vector[i] - mean) * stddev_inv;
    }
    input_vector += v_size;
    output_vector += v_size;
  }
}

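// Rescales the int8 input and recurrent gate contributions to a common int16
// scale using their effective (multiplier, shift) pairs, adds them, and
// saturates the sum to the int16 range.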
void PortableTwoGateSaturationgAdd(const int8_t* input, int8_t input_zp,
                                   const int8_t* recurrent, int8_t recurrent_zp,
                                   int32_t input_effective_scale_a,
                                   int32_t input_effective_scale_b,
                                   int32_t recurrent_effective_scale_a,
                                   int32_t recurrent_effective_scale_b,
                                   int32_t n_batch, int32_t n_cell,
                                   int16_t* output) {
  const int32_t int16_max = std::numeric_limits<int16>::max();
  const int32_t int16_min = std::numeric_limits<int16>::min();
  for (int i = 0; i < n_batch * n_cell; ++i) {
    int32_t x = static_cast<int32>(input[i]) - static_cast<int32>(input_zp);
    int32_t h =
        static_cast<int32>(recurrent[i]) - static_cast<int32>(recurrent_zp);
    int32_t x_scaled = MultiplyByQuantizedMultiplier(x, input_effective_scale_a,
                                                     input_effective_scale_b);
    int32_t h_scaled = MultiplyByQuantizedMultiplier(
        h, recurrent_effective_scale_a, recurrent_effective_scale_b);
    int32_t y = h_scaled + x_scaled;
    if (y > int16_max) {
      y = int16_max;
    }
    if (y < int16_min) {
      y = int16_min;
    }
    output[i] = static_cast<int16_t>(y);
  }
}

}  // namespace tensor_utils
}  // namespace tflite