| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
- #ifndef TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_
- #define TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_
- #include <algorithm>
- #include <limits>
- #include "flatbuffers/flatbuffers.h"
- #include "tensorflow/lite/c/builtin_op_data.h"
- #include "tensorflow/lite/c/common.h"
- namespace tflite {
- inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
- inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
- return t->dims->data[dim];
- }
- inline const TfLiteTensor* GetInput(const TfLiteContext* context,
- const TfLiteNode* node, int index) {
- return &context->tensors[node->inputs->data[index]];
- }
- // Note: You must check if result is not null:
- // TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
- // TF_LITE_ENSURE(context, my_tensor != nullptr);
- inline TfLiteTensor* GetVariableInput(TfLiteContext* context,
- const TfLiteNode* node, int index) {
- TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
- return (tensor->is_variable) ? tensor : nullptr;
- }
- inline TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
- int index) {
- return &context->tensors[node->outputs->data[index]];
- }
- inline TfLiteTensor* GetTemporary(TfLiteContext* context,
- const TfLiteNode* node, int index) {
- return &context->tensors[node->temporaries->data[index]];
- }
- inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
- const TfLiteNode* node, int index) {
- return &context->tensors[node->intermediates->data[index]];
- }
- inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
- inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
- inline int NumIntermediates(const TfLiteNode* node) {
- return node->intermediates->size;
- }
- inline int64_t NumElements(const TfLiteIntArray* dims) {
- int64_t count = 1;
- for (int i = 0; i < dims->size; ++i) {
- count *= dims->data[i];
- }
- return count;
- }
- inline int64_t NumElements(const TfLiteTensor* t) {
- return NumElements(t->dims);
- }
- inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
- const TfLiteNode* node,
- int index) {
- const bool use_tensor = index < node->inputs->size &&
- node->inputs->data[index] != kTfLiteOptionalTensor;
- if (use_tensor) {
- return &context->tensors[node->inputs->data[index]];
- }
- return nullptr;
- }
- // Determines whether tensor is constant.
- // TODO(b/138199592): Introduce new query which checks for constant OR
- // persistent-read-only, which would be useful for most tensor kernels that
- // are potentially dynamic based on the input tensor value availability at the
- // time of prepare.
- inline bool IsConstantTensor(const TfLiteTensor* tensor) {
- return tensor->allocation_type == kTfLiteMmapRo;
- }
- // Determines whether tensor is dynamic. Note that a tensor can be non-const and
- // not dynamic. This function specifically checks for a dynamic tensor.
- inline bool IsDynamicTensor(const TfLiteTensor* tensor) {
- return tensor->allocation_type == kTfLiteDynamic;
- }
- // Sets tensor to dynamic.
- inline void SetTensorToDynamic(TfLiteTensor* tensor) {
- if (tensor->allocation_type != kTfLiteDynamic) {
- tensor->allocation_type = kTfLiteDynamic;
- tensor->data.raw = nullptr;
- }
- }
- // Sets tensor to persistent and read-only.
- inline void SetTensorToPersistentRo(TfLiteTensor* tensor) {
- if (tensor->allocation_type != kTfLitePersistentRo) {
- tensor->allocation_type = kTfLitePersistentRo;
- tensor->data.raw = nullptr;
- }
- }
- // Determines whether it is a hybrid op - one that has float inputs and
- // quantized weights.
- inline bool IsHybridOp(const TfLiteTensor* input, const TfLiteTensor* weight) {
- return ((weight->type == kTfLiteUInt8 || weight->type == kTfLiteInt8) &&
- input->type == kTfLiteFloat32);
- }
// Check dimensionality match and populate OpData for Conv and DepthwiseConv.
// Writes the computed quantized-multiplier/shift pair, the clamped activation
// range, and (via the per_channel_* out-params) per-channel quantization data.
// NOTE(review): this overload takes no channel count — presumably it derives
// the number of channels from the filter tensor; confirm in the .cc file.
TfLiteStatus PopulateConvolutionQuantizationParams(
    TfLiteContext* context, const TfLiteTensor* input,
    const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output,
    const TfLiteFusedActivation& activation, int32_t* multiplier, int* shift,
    int32_t* output_activation_min, int32_t* output_activation_max,
    int32_t* per_channel_multiplier, int* per_channel_shift);

// Overload that additionally takes an explicit `num_channels`, the length of
// the per_channel_multiplier / per_channel_shift output arrays.
TfLiteStatus PopulateConvolutionQuantizationParams(
    TfLiteContext* context, const TfLiteTensor* input,
    const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output,
    const TfLiteFusedActivation& activation, int32_t* multiplier, int* shift,
    int32_t* output_activation_min, int32_t* output_activation_max,
    int32_t* per_channel_multiplier, int* per_channel_shift, int num_channels);
// Calculates the multiplication factor for a quantized convolution (or
// quantized depthwise convolution) involving the given tensors. Returns an
// error if the scales of the tensors are not compatible.
// NOTE(review): "Multipler" (sic) is the established public name; renaming it
// would break callers, so the spelling is kept.
TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
                                              const TfLiteTensor* input,
                                              const TfLiteTensor* filter,
                                              const TfLiteTensor* bias,
                                              TfLiteTensor* output,
                                              double* multiplier);

// Overload for convolutions without a bias tensor.
TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
                                              const TfLiteTensor* input,
                                              const TfLiteTensor* filter,
                                              TfLiteTensor* output,
                                              double* multiplier);
// Calculates the useful quantized range of an activation layer given its
// activation tensor. On success, *act_min/*act_max hold the clamp bounds in
// the output tensor's quantized domain.
TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context,
                                               TfLiteFusedActivation activation,
                                               TfLiteTensor* output,
                                               int32_t* act_min,
                                               int32_t* act_max);
- // Calculates the useful range of an activation layer given its activation
- // tensor.a
- template <typename T>
- void CalculateActivationRange(TfLiteFusedActivation activation,
- T* activation_min, T* activation_max) {
- if (activation == kTfLiteActRelu) {
- *activation_min = 0;
- *activation_max = std::numeric_limits<T>::max();
- } else if (activation == kTfLiteActRelu6) {
- *activation_min = 0;
- *activation_max = 6;
- } else if (activation == kTfLiteActRelu1) {
- *activation_min = -1;
- *activation_max = 1;
- } else {
- *activation_min = std::numeric_limits<T>::lowest();
- *activation_max = std::numeric_limits<T>::max();
- }
- }
// Return true if the given tensors have the same shape (element-wise equal
// dims arrays).
bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2);
// Calculates the output_shape that is necessary for element-wise operations
// with broadcasting involving the two input tensors. On success the caller
// receives a newly allocated TfLiteIntArray via *output_shape — presumably
// the caller takes ownership; confirm against the .cc implementation.
TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
                                        const TfLiteTensor* input1,
                                        const TfLiteTensor* input2,
                                        TfLiteIntArray** output_shape);

// Calculates the output_shape that is necessary for element-wise operations
// with broadcasting involving the three input tensors.
TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
                                        const TfLiteTensor* input1,
                                        const TfLiteTensor* input2,
                                        const TfLiteTensor* input3,
                                        TfLiteIntArray** output_shape);
- } // namespace tflite
- #endif // TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_
|