| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247 |
- /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
- #ifndef TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
- #define TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
- #include <cstdio>
- #include <cstring>
- #include "packages/TensorflowLiteMicro/tensorflow/lite/c/common.h"
- #include "packages/TensorflowLiteMicro/tensorflow/lite/core/api/error_reporter.h"
- #include "packages/TensorflowLiteMicro/tensorflow/lite/core/api/flatbuffer_conversions.h"
- #include "packages/TensorflowLiteMicro/tensorflow/lite/kernels/internal/compatibility.h"
- #include "packages/TensorflowLiteMicro/tensorflow/lite/kernels/op_macros.h"
- #include "packages/TensorflowLiteMicro/tensorflow/lite/micro/compatibility.h"
- #include "packages/TensorflowLiteMicro/tensorflow/lite/micro/kernels/micro_ops.h"
- #include "packages/TensorflowLiteMicro/tensorflow/lite/micro/micro_op_resolver.h"
- #include "packages/TensorflowLiteMicro/tensorflow/lite/schema/schema_generated.h"
- namespace tflite {
- template <unsigned int tOpCount>
- class MicroMutableOpResolver : public MicroOpResolver {
- public:
- explicit MicroMutableOpResolver(ErrorReporter* error_reporter = nullptr)
- : error_reporter_(error_reporter) {}
- const TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override {
- if (op == BuiltinOperator_CUSTOM) return nullptr;
- for (unsigned int i = 0; i < registrations_len_; ++i) {
- const TfLiteRegistration& registration = registrations_[i];
- if (registration.builtin_code == op) {
- return ®istration;
- }
- }
- return nullptr;
- }
- const TfLiteRegistration* FindOp(const char* op) const override {
- for (unsigned int i = 0; i < registrations_len_; ++i) {
- const TfLiteRegistration& registration = registrations_[i];
- if ((registration.builtin_code == BuiltinOperator_CUSTOM) &&
- (strcmp(registration.custom_name, op) == 0)) {
- return ®istration;
- }
- }
- return nullptr;
- }
- MicroOpResolver::BuiltinParseFunction GetOpDataParser(
- BuiltinOperator op) const override {
- TFLITE_DCHECK(num_buitin_ops_ <= tOpCount);
- for (unsigned int i = 0; i < num_buitin_ops_; ++i) {
- if (builtin_codes_[i] == op) return builtin_parsers_[i];
- }
- return nullptr;
- }
- // Registers a Custom Operator with the MicroOpResolver.
- //
- // Only the first call for a given name will be successful. i.e. if this
- // function is called again for a previously added Custom Operator, the
- // MicroOpResolver will be unchanged and this function will return
- // kTfLiteError.
- TfLiteStatus AddCustom(const char* name, TfLiteRegistration* registration) {
- if (registrations_len_ >= tOpCount) {
- if (error_reporter_) {
- TF_LITE_REPORT_ERROR(
- error_reporter_,
- "Couldn't register custom op '%s', resolver size is too small (%d)",
- name, tOpCount);
- }
- return kTfLiteError;
- }
- if (FindOp(name) != nullptr) {
- if (error_reporter_ != nullptr) {
- TF_LITE_REPORT_ERROR(error_reporter_,
- "Calling AddCustom for the same op more than once "
- "is not supported (Op: %s).",
- name);
- }
- return kTfLiteError;
- }
- TfLiteRegistration* new_registration = ®istrations_[registrations_len_];
- registrations_len_ += 1;
- *new_registration = *registration;
- new_registration->builtin_code = BuiltinOperator_CUSTOM;
- new_registration->custom_name = name;
- return kTfLiteOk;
- }
- // Registers a Builtin Operator with the MicroOpResolver.
- //
- // Only the first call for a given BuiltinOperator enum will be successful.
- // i.e. if this function is called again for a previously added
- // BuiltinOperator, the MicroOpResolver will be unchanged and this function
- // will return kTfLiteError.
- //
- // TODO(b/149408647): remove this API once the BuiltinOperator specific Add
- // functions are fully implemented.
- TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
- TfLiteRegistration* registration) {
- TFLITE_DCHECK(registration != nullptr);
- // For code that is not switched over to the new selective registration of
- // the parse function, we pass in ParseOpData. This allows for backwards
- // compatibility.
- return AddBuiltin(op, *registration, ParseOpData);
- }
- // The Add* functions below add the various Builtin operators to the
- // MicroMutableOpResolver object.
- //
- // This API is currently experimental (and only supported for a small subset
- // of operators). It will soon be preferred over the AddBuiltin function for
- // the following reason:
- // * If all calls to AddBuiltin for an application use this API, the code
- // size will be smaller by 5-8K (compared to the using the AddBuiltin
- // override).
- TfLiteStatus AddConv2D() {
- // TODO(b/149408647): Replace ParseOpData with the operator specific parse
- // function once cl/313453102 lands.
- return AddBuiltin(BuiltinOperator_CONV_2D,
- *tflite::ops::micro::Register_CONV_2D(), ParseOpData);
- }
- TfLiteStatus AddDequantize() {
- return AddBuiltin(BuiltinOperator_DEQUANTIZE,
- *tflite::ops::micro::Register_DEQUANTIZE(),
- ParseDequantize);
- }
- TfLiteStatus AddFullyConnected() {
- return AddBuiltin(BuiltinOperator_FULLY_CONNECTED,
- *tflite::ops::micro::Register_FULLY_CONNECTED(),
- ParseFullyConnected);
- }
- TfLiteStatus AddLogistic() {
- // TODO(b/149408647): Replace ParseOpData with the operator specific parse
- // function once cl/313453102 lands.
- return AddBuiltin(BuiltinOperator_LOGISTIC,
- *tflite::ops::micro::Register_LOGISTIC(), ParseOpData);
- }
- TfLiteStatus AddQuantize() {
- return AddBuiltin(BuiltinOperator_QUANTIZE,
- *tflite::ops::micro::Register_QUANTIZE(), ParseQuantize);
- }
- TfLiteStatus AddReshape() {
- // TODO(b/149408647): Replace ParseOpData with the operator specific parse
- // function once cl/313453102 lands.
- return AddBuiltin(BuiltinOperator_RESHAPE,
- *tflite::ops::micro::Register_RESHAPE(), ParseOpData);
- }
- TfLiteStatus AddSoftmax() {
- return AddBuiltin(BuiltinOperator_SOFTMAX,
- *tflite::ops::micro::Register_SOFTMAX(), ParseSoftmax);
- }
- TfLiteStatus AddSvdf() {
- return AddBuiltin(BuiltinOperator_SVDF,
- *tflite::ops::micro::Register_SVDF(), ParseSvdf);
- }
- unsigned int GetRegistrationLength() { return registrations_len_; }
- private:
- TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
- const TfLiteRegistration& registration,
- MicroOpResolver::BuiltinParseFunction parser) {
- if (op == BuiltinOperator_CUSTOM) {
- if (error_reporter_ != nullptr) {
- TF_LITE_REPORT_ERROR(error_reporter_,
- "Invalid parameter BuiltinOperator_CUSTOM to the "
- "AddBuiltin function.");
- }
- return kTfLiteError;
- }
- if (FindOp(op) != nullptr) {
- if (error_reporter_ != nullptr) {
- TF_LITE_REPORT_ERROR(error_reporter_,
- "Calling AddBuiltin with the same op more than "
- "once is not supported (Op: #%d).",
- op);
- }
- return kTfLiteError;
- }
- if (registrations_len_ >= tOpCount) {
- if (error_reporter_) {
- TF_LITE_REPORT_ERROR(error_reporter_,
- "Couldn't register builtin op #%d, resolver size "
- "is too small (%d).",
- op, tOpCount);
- }
- return kTfLiteError;
- }
- registrations_[registrations_len_] = registration;
- // Strictly speaking, the builtin_code is not necessary for TFLM but filling
- // it in regardless.
- registrations_[registrations_len_].builtin_code = op;
- registrations_len_++;
- builtin_codes_[num_buitin_ops_] = op;
- builtin_parsers_[num_buitin_ops_] = parser;
- num_buitin_ops_++;
- return kTfLiteOk;
- }
- TfLiteRegistration registrations_[tOpCount];
- unsigned int registrations_len_ = 0;
- // Arrays (and counter) to store the builtin codes and their corresponding
- // parse functions as these are registered with the Op Resolver.
- BuiltinOperator builtin_codes_[tOpCount];
- MicroOpResolver::BuiltinParseFunction builtin_parsers_[tOpCount];
- unsigned int num_buitin_ops_ = 0;
- ErrorReporter* error_reporter_;
- TF_LITE_REMOVE_VIRTUAL_DELETE
- };
- }; // namespace tflite
- #endif // TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
|