||
- /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
- #ifndef TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
- #define TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
- #include <cstdio>
- #include <cstring>
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/core/api/error_reporter.h"
- #include "tensorflow/lite/core/api/flatbuffer_conversions.h"
- #include "tensorflow/lite/kernels/internal/compatibility.h"
- #include "tensorflow/lite/kernels/op_macros.h"
- #include "tensorflow/lite/micro/compatibility.h"
- #include "tensorflow/lite/micro/kernels/micro_ops.h"
- #include "tensorflow/lite/micro/micro_op_resolver.h"
- #include "tensorflow/lite/schema/schema_generated.h"
- namespace tflite {
- // Op resolver that holds at most `tOpCount` operator registrations in
- // fixed-size member arrays sized by the template parameter.
- //
- // Builtin operators are added through the Add*() helpers, which also record
- // the parse function returned later by GetOpDataParser(). Custom operators
- // are added through AddCustom() and are looked up by name. All lookups are
- // linear scans over the registrations added so far.
- template <unsigned int tOpCount>
- class MicroMutableOpResolver : public MicroOpResolver {
-  public:
-   // `error_reporter` is optional. When it is null, failures are signalled
-   // only through the returned TfLiteStatus and nothing is logged.
-   explicit MicroMutableOpResolver(ErrorReporter* error_reporter = nullptr)
-       : error_reporter_(error_reporter) {}
-   // Returns the registration previously added for the builtin operator `op`,
-   // or nullptr if there is none. BuiltinOperator_CUSTOM never matches here;
-   // custom operators must be looked up by name.
-   const TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override {
-     if (op == BuiltinOperator_CUSTOM) return nullptr;
-     for (unsigned int i = 0; i < registrations_len_; ++i) {
-       const TfLiteRegistration& registration = registrations_[i];
-       if (registration.builtin_code == op) {
-         return &registration;
-       }
-     }
-     return nullptr;
-   }
-   // Returns the registration of the custom operator whose name equals `op`,
-   // or nullptr if there is none (or if `op` is null).
-   const TfLiteRegistration* FindOp(const char* op) const override {
-     // strcmp() must not be handed a null pointer.
-     if (op == nullptr) return nullptr;
-     for (unsigned int i = 0; i < registrations_len_; ++i) {
-       const TfLiteRegistration& registration = registrations_[i];
-       if ((registration.builtin_code == BuiltinOperator_CUSTOM) &&
-           (strcmp(registration.custom_name, op) == 0)) {
-         return &registration;
-       }
-     }
-     return nullptr;
-   }
-   // Returns the parse function recorded for the builtin operator `op`, or
-   // nullptr if that operator has not been added.
-   MicroOpResolver::BuiltinParseFunction GetOpDataParser(
-       BuiltinOperator op) const override {
-     TFLITE_DCHECK(num_builtin_ops_ <= tOpCount);
-     for (unsigned int i = 0; i < num_builtin_ops_; ++i) {
-       if (builtin_codes_[i] == op) return builtin_parsers_[i];
-     }
-     return nullptr;
-   }
-   // Registers a Custom Operator with the MicroOpResolver.
-   //
-   // Only the first call for a given name will be successful. i.e. if this
-   // function is called again for a previously added Custom Operator, the
-   // MicroOpResolver will be unchanged and this function will return
-   // kTfLiteError. kTfLiteError is also returned when `name` or `registration`
-   // is null, or when the resolver already holds `tOpCount` registrations.
-   //
-   // `*registration` is copied, but `name` is stored as a pointer and is not
-   // copied, so the string it points to must stay valid for as long as the
-   // resolver is used.
-   TfLiteStatus AddCustom(const char* name, TfLiteRegistration* registration) {
-     if (name == nullptr || registration == nullptr) {
-       if (error_reporter_ != nullptr) {
-         TF_LITE_REPORT_ERROR(error_reporter_,
-                              "Invalid null parameter passed to AddCustom.");
-       }
-       return kTfLiteError;
-     }
-     if (registrations_len_ >= tOpCount) {
-       if (error_reporter_ != nullptr) {
-         TF_LITE_REPORT_ERROR(
-             error_reporter_,
-             "Couldn't register custom op '%s', resolver size is too small (%d)",
-             name, static_cast<int>(tOpCount));
-       }
-       return kTfLiteError;
-     }
-     if (FindOp(name) != nullptr) {
-       if (error_reporter_ != nullptr) {
-         TF_LITE_REPORT_ERROR(error_reporter_,
-                              "Calling AddCustom for the same op more than once "
-                              "is not supported (Op: %s).",
-                              name);
-       }
-       return kTfLiteError;
-     }
-     TfLiteRegistration* new_registration = &registrations_[registrations_len_];
-     registrations_len_ += 1;
-     *new_registration = *registration;
-     // Overwrite whatever the caller supplied so that name-based lookup works.
-     new_registration->builtin_code = BuiltinOperator_CUSTOM;
-     new_registration->custom_name = name;
-     return kTfLiteOk;
-   }
-   // The Add* functions below add the various Builtin operators to the
-   // MicroMutableOpResolver object. Each returns the status of the underlying
-   // AddBuiltin() / AddCustom() call.
-   TfLiteStatus AddAbs() {
-     return AddBuiltin(BuiltinOperator_ABS, tflite::ops::micro::Register_ABS(),
-                       ParseAbs);
-   }
-   TfLiteStatus AddAdd() {
-     return AddBuiltin(BuiltinOperator_ADD, tflite::ops::micro::Register_ADD(),
-                       ParseAdd);
-   }
-   TfLiteStatus AddArgMax() {
-     return AddBuiltin(BuiltinOperator_ARG_MAX,
-                       tflite::ops::micro::Register_ARG_MAX(), ParseArgMax);
-   }
-   TfLiteStatus AddArgMin() {
-     return AddBuiltin(BuiltinOperator_ARG_MIN,
-                       tflite::ops::micro::Register_ARG_MIN(), ParseArgMin);
-   }
-   TfLiteStatus AddAveragePool2D() {
-     return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D,
-                       tflite::ops::micro::Register_AVERAGE_POOL_2D(),
-                       ParsePool);
-   }
-   TfLiteStatus AddCeil() {
-     return AddBuiltin(BuiltinOperator_CEIL, tflite::ops::micro::Register_CEIL(),
-                       ParseCeil);
-   }
-   // CIRCULAR_BUFFER is registered as a custom operator, not a builtin.
-   TfLiteStatus AddCircularBuffer() {
-     return AddCustom("CIRCULAR_BUFFER",
-                      tflite::ops::micro::Register_CIRCULAR_BUFFER());
-   }
-   TfLiteStatus AddConcatenation() {
-     return AddBuiltin(BuiltinOperator_CONCATENATION,
-                       tflite::ops::micro::Register_CONCATENATION(),
-                       ParseConcatenation);
-   }
-   TfLiteStatus AddConv2D() {
-     return AddBuiltin(BuiltinOperator_CONV_2D,
-                       tflite::ops::micro::Register_CONV_2D(), ParseConv2D);
-   }
-   TfLiteStatus AddCos() {
-     return AddBuiltin(BuiltinOperator_COS, tflite::ops::micro::Register_COS(),
-                       ParseCos);
-   }
-   TfLiteStatus AddDepthwiseConv2D() {
-     return AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D,
-                       tflite::ops::micro::Register_DEPTHWISE_CONV_2D(),
-                       ParseDepthwiseConv2D);
-   }
-   TfLiteStatus AddDequantize() {
-     return AddBuiltin(BuiltinOperator_DEQUANTIZE,
-                       tflite::ops::micro::Register_DEQUANTIZE(),
-                       ParseDequantize);
-   }
-   TfLiteStatus AddEqual() {
-     return AddBuiltin(BuiltinOperator_EQUAL,
-                       tflite::ops::micro::Register_EQUAL(), ParseEqual);
-   }
-   TfLiteStatus AddFloor() {
-     return AddBuiltin(BuiltinOperator_FLOOR,
-                       tflite::ops::micro::Register_FLOOR(), ParseFloor);
-   }
-   TfLiteStatus AddFullyConnected() {
-     return AddBuiltin(BuiltinOperator_FULLY_CONNECTED,
-                       tflite::ops::micro::Register_FULLY_CONNECTED(),
-                       ParseFullyConnected);
-   }
-   TfLiteStatus AddGreater() {
-     return AddBuiltin(BuiltinOperator_GREATER,
-                       tflite::ops::micro::Register_GREATER(), ParseGreater);
-   }
-   TfLiteStatus AddGreaterEqual() {
-     return AddBuiltin(BuiltinOperator_GREATER_EQUAL,
-                       tflite::ops::micro::Register_GREATER_EQUAL(),
-                       ParseGreaterEqual);
-   }
-   TfLiteStatus AddHardSwish() {
-     return AddBuiltin(BuiltinOperator_HARD_SWISH,
-                       tflite::ops::micro::Register_HARD_SWISH(),
-                       ParseHardSwish);
-   }
-   TfLiteStatus AddL2Normalization() {
-     return AddBuiltin(BuiltinOperator_L2_NORMALIZATION,
-                       tflite::ops::micro::Register_L2_NORMALIZATION(),
-                       ParseL2Normalization);
-   }
-   TfLiteStatus AddLess() {
-     return AddBuiltin(BuiltinOperator_LESS, tflite::ops::micro::Register_LESS(),
-                       ParseLess);
-   }
-   TfLiteStatus AddLessEqual() {
-     return AddBuiltin(BuiltinOperator_LESS_EQUAL,
-                       tflite::ops::micro::Register_LESS_EQUAL(),
-                       ParseLessEqual);
-   }
-   TfLiteStatus AddLog() {
-     return AddBuiltin(BuiltinOperator_LOG, tflite::ops::micro::Register_LOG(),
-                       ParseLog);
-   }
-   TfLiteStatus AddLogicalAnd() {
-     return AddBuiltin(BuiltinOperator_LOGICAL_AND,
-                       tflite::ops::micro::Register_LOGICAL_AND(),
-                       ParseLogicalAnd);
-   }
-   TfLiteStatus AddLogicalNot() {
-     return AddBuiltin(BuiltinOperator_LOGICAL_NOT,
-                       tflite::ops::micro::Register_LOGICAL_NOT(),
-                       ParseLogicalNot);
-   }
-   TfLiteStatus AddLogicalOr() {
-     return AddBuiltin(BuiltinOperator_LOGICAL_OR,
-                       tflite::ops::micro::Register_LOGICAL_OR(),
-                       ParseLogicalOr);
-   }
-   TfLiteStatus AddLogistic() {
-     return AddBuiltin(BuiltinOperator_LOGISTIC,
-                       tflite::ops::micro::Register_LOGISTIC(), ParseLogistic);
-   }
-   TfLiteStatus AddMaximum() {
-     return AddBuiltin(BuiltinOperator_MAXIMUM,
-                       tflite::ops::micro::Register_MAXIMUM(), ParseMaximum);
-   }
-   TfLiteStatus AddMaxPool2D() {
-     return AddBuiltin(BuiltinOperator_MAX_POOL_2D,
-                       tflite::ops::micro::Register_MAX_POOL_2D(), ParsePool);
-   }
-   TfLiteStatus AddMean() {
-     return AddBuiltin(BuiltinOperator_MEAN, tflite::ops::micro::Register_MEAN(),
-                       ParseReducer);
-   }
-   TfLiteStatus AddMinimum() {
-     return AddBuiltin(BuiltinOperator_MINIMUM,
-                       tflite::ops::micro::Register_MINIMUM(), ParseMinimum);
-   }
-   TfLiteStatus AddMul() {
-     return AddBuiltin(BuiltinOperator_MUL, tflite::ops::micro::Register_MUL(),
-                       ParseMul);
-   }
-   TfLiteStatus AddNeg() {
-     return AddBuiltin(BuiltinOperator_NEG, tflite::ops::micro::Register_NEG(),
-                       ParseNeg);
-   }
-   TfLiteStatus AddNotEqual() {
-     return AddBuiltin(BuiltinOperator_NOT_EQUAL,
-                       tflite::ops::micro::Register_NOT_EQUAL(), ParseNotEqual);
-   }
-   TfLiteStatus AddPack() {
-     return AddBuiltin(BuiltinOperator_PACK, tflite::ops::micro::Register_PACK(),
-                       ParsePack);
-   }
-   TfLiteStatus AddPad() {
-     return AddBuiltin(BuiltinOperator_PAD, tflite::ops::micro::Register_PAD(),
-                       ParsePad);
-   }
-   TfLiteStatus AddPadV2() {
-     return AddBuiltin(BuiltinOperator_PADV2,
-                       tflite::ops::micro::Register_PADV2(), ParsePadV2);
-   }
-   TfLiteStatus AddPrelu() {
-     return AddBuiltin(BuiltinOperator_PRELU,
-                       tflite::ops::micro::Register_PRELU(), ParsePrelu);
-   }
-   TfLiteStatus AddQuantize() {
-     return AddBuiltin(BuiltinOperator_QUANTIZE,
-                       tflite::ops::micro::Register_QUANTIZE(), ParseQuantize);
-   }
-   TfLiteStatus AddRelu() {
-     return AddBuiltin(BuiltinOperator_RELU, tflite::ops::micro::Register_RELU(),
-                       ParseRelu);
-   }
-   TfLiteStatus AddRelu6() {
-     return AddBuiltin(BuiltinOperator_RELU6,
-                       tflite::ops::micro::Register_RELU6(), ParseRelu6);
-   }
-   TfLiteStatus AddReshape() {
-     return AddBuiltin(BuiltinOperator_RESHAPE,
-                       tflite::ops::micro::Register_RESHAPE(), ParseReshape);
-   }
-   TfLiteStatus AddResizeNearestNeighbor() {
-     return AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
-                       tflite::ops::micro::Register_RESIZE_NEAREST_NEIGHBOR(),
-                       ParseResizeNearestNeighbor);
-   }
-   TfLiteStatus AddRound() {
-     return AddBuiltin(BuiltinOperator_ROUND,
-                       tflite::ops::micro::Register_ROUND(), ParseRound);
-   }
-   TfLiteStatus AddRsqrt() {
-     return AddBuiltin(BuiltinOperator_RSQRT,
-                       tflite::ops::micro::Register_RSQRT(), ParseRsqrt);
-   }
-   TfLiteStatus AddSin() {
-     return AddBuiltin(BuiltinOperator_SIN, tflite::ops::micro::Register_SIN(),
-                       ParseSin);
-   }
-   TfLiteStatus AddSoftmax() {
-     return AddBuiltin(BuiltinOperator_SOFTMAX,
-                       tflite::ops::micro::Register_SOFTMAX(), ParseSoftmax);
-   }
-   TfLiteStatus AddSplit() {
-     return AddBuiltin(BuiltinOperator_SPLIT,
-                       tflite::ops::micro::Register_SPLIT(), ParseSplit);
-   }
-   TfLiteStatus AddSqrt() {
-     return AddBuiltin(BuiltinOperator_SQRT, tflite::ops::micro::Register_SQRT(),
-                       ParseSqrt);
-   }
-   TfLiteStatus AddSquare() {
-     return AddBuiltin(BuiltinOperator_SQUARE,
-                       tflite::ops::micro::Register_SQUARE(), ParseSquare);
-   }
-   TfLiteStatus AddStridedSlice() {
-     return AddBuiltin(BuiltinOperator_STRIDED_SLICE,
-                       tflite::ops::micro::Register_STRIDED_SLICE(),
-                       ParseStridedSlice);
-   }
-   TfLiteStatus AddSub() {
-     return AddBuiltin(BuiltinOperator_SUB, tflite::ops::micro::Register_SUB(),
-                       ParseSub);
-   }
-   TfLiteStatus AddSvdf() {
-     return AddBuiltin(BuiltinOperator_SVDF, tflite::ops::micro::Register_SVDF(),
-                       ParseSvdf);
-   }
-   TfLiteStatus AddTanh() {
-     return AddBuiltin(BuiltinOperator_TANH, tflite::ops::micro::Register_TANH(),
-                       ParseTanh);
-   }
-   TfLiteStatus AddUnpack() {
-     return AddBuiltin(BuiltinOperator_UNPACK,
-                       tflite::ops::micro::Register_UNPACK(), ParseUnpack);
-   }
-   // Number of registrations (builtin and custom) added so far.
-   unsigned int GetRegistrationLength() const { return registrations_len_; }
-  private:
-   // Stores a copy of `registration` for the builtin operator `op` and records
-   // `parser` as its parse function.
-   //
-   // Returns kTfLiteError, leaving the resolver unchanged, when `op` is
-   // BuiltinOperator_CUSTOM, when `op` has already been added, or when the
-   // resolver already holds `tOpCount` registrations.
-   TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
-                           const TfLiteRegistration& registration,
-                           MicroOpResolver::BuiltinParseFunction parser) {
-     if (op == BuiltinOperator_CUSTOM) {
-       if (error_reporter_ != nullptr) {
-         TF_LITE_REPORT_ERROR(error_reporter_,
-                              "Invalid parameter BuiltinOperator_CUSTOM to the "
-                              "AddBuiltin function.");
-       }
-       return kTfLiteError;
-     }
-     if (FindOp(op) != nullptr) {
-       if (error_reporter_ != nullptr) {
-         TF_LITE_REPORT_ERROR(error_reporter_,
-                              "Calling AddBuiltin with the same op more than "
-                              "once is not supported (Op: #%d).",
-                              static_cast<int>(op));
-       }
-       return kTfLiteError;
-     }
-     if (registrations_len_ >= tOpCount) {
-       if (error_reporter_ != nullptr) {
-         TF_LITE_REPORT_ERROR(error_reporter_,
-                              "Couldn't register builtin op #%d, resolver size "
-                              "is too small (%d).",
-                              static_cast<int>(op), static_cast<int>(tOpCount));
-       }
-       return kTfLiteError;
-     }
-     registrations_[registrations_len_] = registration;
-     // Strictly speaking, the builtin_code is not necessary for TFLM but filling
-     // it in regardless.
-     registrations_[registrations_len_].builtin_code = op;
-     registrations_len_++;
-     // num_builtin_ops_ never exceeds registrations_len_, which was checked
-     // against tOpCount above, so these writes stay within the arrays.
-     builtin_codes_[num_builtin_ops_] = op;
-     builtin_parsers_[num_builtin_ops_] = parser;
-     num_builtin_ops_++;
-     return kTfLiteOk;
-   }
-   // Registrations for builtin and custom operators, in insertion order.
-   TfLiteRegistration registrations_[tOpCount];
-   unsigned int registrations_len_ = 0;
-   // Arrays (and counter) to store the builtin codes and their corresponding
-   // parse functions as these are registered with the Op Resolver.
-   BuiltinOperator builtin_codes_[tOpCount];
-   MicroOpResolver::BuiltinParseFunction builtin_parsers_[tOpCount];
-   unsigned int num_builtin_ops_ = 0;
-   // Optional sink for error messages; may be null.
-   ErrorReporter* error_reporter_;
-   TF_LITE_REMOVE_VIRTUAL_DELETE
- };
- }; // namespace tflite
- #endif // TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
|