/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_

#include <cstdint>  // int8_t / int32_t
#include <cstring>  // memset

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace reference_ops {

inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
                        const RuntimeShape& rhs_shape, const float* rhs_data,
                        const RuntimeShape& output_shape, float* output_data) {
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);
  // Determine the size of the broadcast batch dimension: equal dims pass
  // through, and a dim of 1 broadcasts against the other side.
  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
    if (lhs_dim == rhs_dim) return lhs_dim;
    if (lhs_dim == 1) return rhs_dim;
    TFLITE_DCHECK_EQ(rhs_dim, 1);
    return lhs_dim;
  };
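  // E.g. broadcast_dim(1, 4) == 4 and broadcast_dim(4, 4) == 4; two unequal
  // non-unit batch dims are a checked error.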
  // Compute the "extent" for iterating on this dimension: the flat stride, in
  // elements, of one step along it. If we are broadcasting, don't advance
  // (i.e. return 0).
  auto extent = [](const RuntimeShape& shape, int x) {
    if (shape.Dims(x) == 1) {
      return 0;
    }
    int prod = 1;
    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
      prod *= shape.Dims(i);
    }
    return prod;
  };
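  // E.g. for a shape [2, 1, 3, 4, 5]: extent(shape, 0) == 1 * 3 * 4 * 5 == 60,
  // extent(shape, 1) == 0 (broadcast), and extent(shape, 2) == 4 * 5 == 20.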
  const int batch_dim0 =
      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 =
      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 =
      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = extent(extended_lhs_shape, 0);
  const int lhs_ext1 = extent(extended_lhs_shape, 1);
  const int lhs_ext2 = extent(extended_lhs_shape, 2);
  const int rhs_ext0 = extent(extended_rhs_shape, 0);
  const int rhs_ext1 = extent(extended_rhs_shape, 1);
  const int rhs_ext2 = extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);
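  // Note: the index arithmetic below reads the RHS buffer as
  // [rhs_cols, accum_depth], i.e. transposed relative to a row-major
  // [accum_depth, rhs_cols] matrix, and writes each output block with the
  // column index outermost (out_ptr[lhs_rows * j + i]), i.e. column-major.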
  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                        b1 * batch_dim2 + b2) *
                                           lhs_rows * rhs_cols;
        for (int j = 0; j < rhs_cols; ++j) {
          for (int i = 0; i < lhs_rows; ++i) {
            float total = 0.f;
            for (int k = 0; k < accum_depth; ++k) {
              total += lhs_ptr2[accum_depth * i + k] *
                       rhs_ptr2[j * accum_depth + k];
            }
            int idx = lhs_rows * j + i;
            out_ptr[idx] = total;
          }
        }
      }
    }
  }
}
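
// Illustrative usage sketch (hypothetical values, not from the original
// source): a single-batch 2x3 by 3x2 multiply. Per the layout note above, the
// RHS buffer is supplied transposed and the result comes back column-major.
// Assumes RuntimeShape's initializer_list constructor.
//
//   const float lhs[6] = {1, 2, 3,   // A, row-major [2, 3]
//                         4, 5, 6};
//   const float rhs[6] = {1, 0, 0,   // B^T, row-major [2, 3];
//                         0, 1, 0};  // B itself is [[1, 0], [0, 1], [0, 0]]
//   float out[4];
//   BatchMatMul(RuntimeShape({1, 1, 1, 2, 3}), lhs,
//               RuntimeShape({1, 1, 1, 3, 2}), rhs,
//               RuntimeShape({1, 1, 1, 2, 2}), out);
//   // out == {1, 4, 2, 5}, i.e. A*B == [[1, 2], [4, 5]] stored column-major.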

inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
                        const float* scaling_factors,
                        const int32_t* input_offset, int32_t* row_sums,
                        const RuntimeShape& output_shape, float* output_data,
                        bool* compute_row_sums) {
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);
  // Determine the size of the broadcast batch dimension: equal dims pass
  // through, and a dim of 1 broadcasts against the other side.
  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
    if (lhs_dim == rhs_dim) return lhs_dim;
    if (lhs_dim == 1) return rhs_dim;
    TFLITE_DCHECK_EQ(rhs_dim, 1);
    return lhs_dim;
  };
  // Compute the "extent" for iterating on this dimension: the flat stride, in
  // elements, of one step along it. If we are broadcasting, don't advance
  // (i.e. return 0).
  auto extent = [](const RuntimeShape& shape, int x) {
    if (shape.Dims(x) == 1) {
      return 0;
    }
    int prod = 1;
    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
      prod *= shape.Dims(i);
    }
    return prod;
  };
  const int batch_dim0 =
      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 =
      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 =
      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = extent(extended_lhs_shape, 0);
  const int lhs_ext1 = extent(extended_lhs_shape, 1);
  const int lhs_ext2 = extent(extended_lhs_shape, 2);
  const int rhs_ext0 = extent(extended_rhs_shape, 0);
  const int rhs_ext1 = extent(extended_rhs_shape, 1);
  const int rhs_ext2 = extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);
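  // The per-column scaling factors and input offsets (one per RHS column) and
  // the per-row weight sums (one per LHS row) advance across batches only when
  // the tensor they belong to does, mirroring the lhs/rhs extents above.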
  const int ioff_ext0 = rhs_ext0 == 0 ? 0 : rhs_cols;
  const int ioff_ext1 = rhs_ext1 == 0 ? 0 : rhs_cols;
  const int ioff_ext2 = rhs_ext2 == 0 ? 0 : rhs_cols;
  const int woff_ext0 = lhs_ext0 == 0 ? 0 : lhs_rows;
  const int woff_ext1 = lhs_ext1 == 0 ? 0 : lhs_rows;
  const int woff_ext2 = lhs_ext2 == 0 ? 0 : lhs_rows;

  if (!compute_row_sums || *compute_row_sums) {
    int num_weights_matrices = 1;
    for (int i = 1; i < extended_lhs_shape.DimensionsCount() - 2; ++i) {
      num_weights_matrices *= extended_lhs_shape.Dims(i);
    }
    memset(row_sums, 0, sizeof(int32_t) * lhs_rows * num_weights_matrices);
    for (int j = 0; j < num_weights_matrices; ++j) {
      tensor_utils::PortableReductionSumVector(
          lhs_data + j * lhs_rows * accum_depth, row_sums + j * lhs_rows,
          lhs_rows, accum_depth);
    }
    if (compute_row_sums) {
      *compute_row_sums = false;
    }
  }
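  // Why row sums (an informal reading of the arithmetic below): with
  // per-column symmetric quantization, x_float[k] ~= scale * (x_q[k] - offset),
  // so dot(w, x_float) ~= scale * (dot(w, x_q) - offset * sum(w)). Precomputing
  // sum(w) per weights row keeps the inner loop in pure integer arithmetic.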
  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    const int32_t* ioff_ptr0 = input_offset + (b0 * ioff_ext0);
    const float* scale_ptr0 = scaling_factors + (b0 * ioff_ext0);
    const int32_t* woff_ptr0 = row_sums + (b0 * woff_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      const int32_t* ioff_ptr1 = ioff_ptr0 + (b1 * ioff_ext1);
      const float* scale_ptr1 = scale_ptr0 + (b1 * ioff_ext1);
      const int32_t* woff_ptr1 = woff_ptr0 + (b1 * woff_ext1);
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        const int32_t* ioff_ptr2 = ioff_ptr1 + (b2 * ioff_ext2);
        const float* scale_ptr2 = scale_ptr1 + (b2 * ioff_ext2);
        const int32_t* woff_ptr2 = woff_ptr1 + (b2 * woff_ext2);
        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                        b1 * batch_dim2 + b2) *
                                           lhs_rows * rhs_cols;
        for (int j = 0; j < rhs_cols; ++j) {
          const float batch_scaling_factor = scale_ptr2[j];
          const float batch_offset = static_cast<float>(ioff_ptr2[j]);
          for (int i = 0; i < lhs_rows; ++i) {
            int32_t total = 0;
            for (int k = 0; k < accum_depth; ++k) {
              total += lhs_ptr2[accum_depth * i + k] *
                       rhs_ptr2[j * accum_depth + k];
            }
            int32_t row_sum = woff_ptr2[i];
            total -= row_sum * batch_offset;
            int idx = lhs_rows * j + i;
            out_ptr[idx] += batch_scaling_factor * total;
          }
        }
      }
    }
  }
}
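
// Illustrative usage sketch for the hybrid overload (hypothetical values, not
// from the original source): one batch, 2x2 int8 weights times a single
// quantized input column, with scale 0.5 and input offset 1. The kernel
// accumulates into output_data, so the buffer must be zeroed first.
//
//   const int8_t weights[4] = {1, 2, 3, 4};   // row-major [2, 2]
//   const int8_t input[2] = {5, 6};           // one column, transposed layout
//   const float scaling_factors[1] = {0.5f};  // one per RHS column
//   const int32_t input_offset[1] = {1};
//   int32_t row_sums[2];                      // filled on the first call
//   bool compute_row_sums = true;
//   float out[2] = {0.f, 0.f};
//   BatchMatMul(RuntimeShape({1, 1, 1, 2, 2}), weights,
//               RuntimeShape({1, 1, 1, 2, 1}), input, scaling_factors,
//               input_offset, row_sums, RuntimeShape({1, 1, 1, 2, 1}), out,
//               &compute_row_sums);
//   // out == {0.5f * (17 - 1 * 3), 0.5f * (39 - 1 * 7)} == {7.0f, 16.0f}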

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_