| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212 |
- /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
- #include <cmath>
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
- #include "tensorflow/lite/kernels/kernel_util.h"
- #include "tensorflow/lite/micro/kernels/kernel_util.h"
- #include "tensorflow/lite/micro/micro_utils.h"
- namespace tflite {
- namespace ops {
- namespace micro {
- namespace elementwise {
- namespace {
- bool IsNumericSupportedType(const TfLiteType type) {
- return type == kTfLiteFloat32;
- }
- bool IsLogicalSupportedType(const TfLiteType type) {
- return type == kTfLiteBool;
- }
- typedef bool (*IsSupportedType)(TfLiteType);
- template <IsSupportedType>
- TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
- TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- const TfLiteTensor* input = GetInput(context, node, 0);
- TfLiteTensor* output = GetOutput(context, node, 0);
- TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
- if (!IsSupportedType(input->type)) {
- TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
- TfLiteTypeGetName(input->type), input->type);
- return kTfLiteError;
- }
- return kTfLiteOk;
- }
- template <typename T>
- inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
- T func(T), TfLiteType expected_type) {
- const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
- TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
- TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
- const size_t num_elements = ElementCount(*input->dims);
- const T* in_data = tflite::micro::GetTensorData<T>(input);
- T* out_data = tflite::micro::GetTensorData<T>(output);
- for (size_t i = 0; i < num_elements; ++i) {
- out_data[i] = func(in_data[i]);
- }
- return kTfLiteOk;
- }
- inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
- float float_func(float)) {
- return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
- }
- inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
- bool bool_func(bool)) {
- return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
- }
- TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
- return EvalNumeric(context, node, std::abs);
- }
- TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
- return EvalNumeric(context, node, std::sin);
- }
- TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
- return EvalNumeric(context, node, std::cos);
- }
- TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
- return EvalNumeric(context, node, std::log);
- }
- TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
- return EvalNumeric(context, node, std::sqrt);
- }
- TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
- return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
- }
- TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
- return EvalNumeric(context, node, [](float f) { return f * f; });
- }
- TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
- return EvalLogical(context, node, [](bool v) { return !v; });
- }
- } // namespace
- } // namespace elementwise
- TfLiteRegistration Register_ABS() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::AbsEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
- }
- TfLiteRegistration Register_SIN() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::SinEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
- }
- TfLiteRegistration Register_COS() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::CosEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
- }
- TfLiteRegistration Register_LOG() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::LogEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
- }
- TfLiteRegistration Register_SQRT() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::SqrtEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
- }
- TfLiteRegistration Register_RSQRT() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::RsqrtEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
- }
- TfLiteRegistration Register_SQUARE() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
- /*invoke=*/elementwise::SquareEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
- }
- TfLiteRegistration Register_LOGICAL_NOT() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/
- elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
- /*invoke=*/elementwise::LogicalNotEval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
- }
- } // namespace micro
- } // namespace ops
- } // namespace tflite
|