micro_mutable_op_resolver.h 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #ifndef TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
  13. #define TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
  14. #include <cstdio>
  15. #include <cstring>
  16. #include "packages/TensorflowLiteMicro/tensorflow/lite/c/common.h"
  17. #include "packages/TensorflowLiteMicro/tensorflow/lite/core/api/error_reporter.h"
  18. #include "packages/TensorflowLiteMicro/tensorflow/lite/core/api/flatbuffer_conversions.h"
  19. #include "packages/TensorflowLiteMicro/tensorflow/lite/kernels/internal/compatibility.h"
  20. #include "packages/TensorflowLiteMicro/tensorflow/lite/kernels/op_macros.h"
  21. #include "packages/TensorflowLiteMicro/tensorflow/lite/micro/compatibility.h"
  22. #include "packages/TensorflowLiteMicro/tensorflow/lite/micro/kernels/micro_ops.h"
  23. #include "packages/TensorflowLiteMicro/tensorflow/lite/micro/micro_op_resolver.h"
  24. #include "packages/TensorflowLiteMicro/tensorflow/lite/schema/schema_generated.h"
  25. namespace tflite {
  26. template <unsigned int tOpCount>
  27. class MicroMutableOpResolver : public MicroOpResolver {
  28. public:
  29. explicit MicroMutableOpResolver(ErrorReporter* error_reporter = nullptr)
  30. : error_reporter_(error_reporter) {}
  31. const TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override {
  32. if (op == BuiltinOperator_CUSTOM) return nullptr;
  33. for (unsigned int i = 0; i < registrations_len_; ++i) {
  34. const TfLiteRegistration& registration = registrations_[i];
  35. if (registration.builtin_code == op) {
  36. return &registration;
  37. }
  38. }
  39. return nullptr;
  40. }
  41. const TfLiteRegistration* FindOp(const char* op) const override {
  42. for (unsigned int i = 0; i < registrations_len_; ++i) {
  43. const TfLiteRegistration& registration = registrations_[i];
  44. if ((registration.builtin_code == BuiltinOperator_CUSTOM) &&
  45. (strcmp(registration.custom_name, op) == 0)) {
  46. return &registration;
  47. }
  48. }
  49. return nullptr;
  50. }
  51. MicroOpResolver::BuiltinParseFunction GetOpDataParser(
  52. BuiltinOperator op) const override {
  53. TFLITE_DCHECK(num_buitin_ops_ <= tOpCount);
  54. for (unsigned int i = 0; i < num_buitin_ops_; ++i) {
  55. if (builtin_codes_[i] == op) return builtin_parsers_[i];
  56. }
  57. return nullptr;
  58. }
  59. // Registers a Custom Operator with the MicroOpResolver.
  60. //
  61. // Only the first call for a given name will be successful. i.e. if this
  62. // function is called again for a previously added Custom Operator, the
  63. // MicroOpResolver will be unchanged and this function will return
  64. // kTfLiteError.
  65. TfLiteStatus AddCustom(const char* name, TfLiteRegistration* registration) {
  66. if (registrations_len_ >= tOpCount) {
  67. if (error_reporter_) {
  68. TF_LITE_REPORT_ERROR(
  69. error_reporter_,
  70. "Couldn't register custom op '%s', resolver size is too small (%d)",
  71. name, tOpCount);
  72. }
  73. return kTfLiteError;
  74. }
  75. if (FindOp(name) != nullptr) {
  76. if (error_reporter_ != nullptr) {
  77. TF_LITE_REPORT_ERROR(error_reporter_,
  78. "Calling AddCustom for the same op more than once "
  79. "is not supported (Op: %s).",
  80. name);
  81. }
  82. return kTfLiteError;
  83. }
  84. TfLiteRegistration* new_registration = &registrations_[registrations_len_];
  85. registrations_len_ += 1;
  86. *new_registration = *registration;
  87. new_registration->builtin_code = BuiltinOperator_CUSTOM;
  88. new_registration->custom_name = name;
  89. return kTfLiteOk;
  90. }
  91. // Registers a Builtin Operator with the MicroOpResolver.
  92. //
  93. // Only the first call for a given BuiltinOperator enum will be successful.
  94. // i.e. if this function is called again for a previously added
  95. // BuiltinOperator, the MicroOpResolver will be unchanged and this function
  96. // will return kTfLiteError.
  97. //
  98. // TODO(b/149408647): remove this API once the BuiltinOperator specific Add
  99. // functions are fully implemented.
  100. TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
  101. TfLiteRegistration* registration) {
  102. TFLITE_DCHECK(registration != nullptr);
  103. // For code that is not switched over to the new selective registration of
  104. // the parse function, we pass in ParseOpData. This allows for backwards
  105. // compatibility.
  106. return AddBuiltin(op, *registration, ParseOpData);
  107. }
  108. // The Add* functions below add the various Builtin operators to the
  109. // MicroMutableOpResolver object.
  110. //
  111. // This API is currently experimental (and only supported for a small subset
  112. // of operators). It will soon be preferred over the AddBuiltin function for
  113. // the following reason:
  114. // * If all calls to AddBuiltin for an application use this API, the code
  115. // size will be smaller by 5-8K (compared to the using the AddBuiltin
  116. // override).
  117. TfLiteStatus AddConv2D() {
  118. // TODO(b/149408647): Replace ParseOpData with the operator specific parse
  119. // function once cl/313453102 lands.
  120. return AddBuiltin(BuiltinOperator_CONV_2D,
  121. *tflite::ops::micro::Register_CONV_2D(), ParseOpData);
  122. }
  123. TfLiteStatus AddDequantize() {
  124. return AddBuiltin(BuiltinOperator_DEQUANTIZE,
  125. *tflite::ops::micro::Register_DEQUANTIZE(),
  126. ParseDequantize);
  127. }
  128. TfLiteStatus AddFullyConnected() {
  129. return AddBuiltin(BuiltinOperator_FULLY_CONNECTED,
  130. *tflite::ops::micro::Register_FULLY_CONNECTED(),
  131. ParseFullyConnected);
  132. }
  133. TfLiteStatus AddLogistic() {
  134. // TODO(b/149408647): Replace ParseOpData with the operator specific parse
  135. // function once cl/313453102 lands.
  136. return AddBuiltin(BuiltinOperator_LOGISTIC,
  137. *tflite::ops::micro::Register_LOGISTIC(), ParseOpData);
  138. }
  139. TfLiteStatus AddQuantize() {
  140. return AddBuiltin(BuiltinOperator_QUANTIZE,
  141. *tflite::ops::micro::Register_QUANTIZE(), ParseQuantize);
  142. }
  143. TfLiteStatus AddReshape() {
  144. // TODO(b/149408647): Replace ParseOpData with the operator specific parse
  145. // function once cl/313453102 lands.
  146. return AddBuiltin(BuiltinOperator_RESHAPE,
  147. *tflite::ops::micro::Register_RESHAPE(), ParseOpData);
  148. }
  149. TfLiteStatus AddSoftmax() {
  150. return AddBuiltin(BuiltinOperator_SOFTMAX,
  151. *tflite::ops::micro::Register_SOFTMAX(), ParseSoftmax);
  152. }
  153. TfLiteStatus AddSvdf() {
  154. return AddBuiltin(BuiltinOperator_SVDF,
  155. *tflite::ops::micro::Register_SVDF(), ParseSvdf);
  156. }
  157. unsigned int GetRegistrationLength() { return registrations_len_; }
  158. private:
  159. TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
  160. const TfLiteRegistration& registration,
  161. MicroOpResolver::BuiltinParseFunction parser) {
  162. if (op == BuiltinOperator_CUSTOM) {
  163. if (error_reporter_ != nullptr) {
  164. TF_LITE_REPORT_ERROR(error_reporter_,
  165. "Invalid parameter BuiltinOperator_CUSTOM to the "
  166. "AddBuiltin function.");
  167. }
  168. return kTfLiteError;
  169. }
  170. if (FindOp(op) != nullptr) {
  171. if (error_reporter_ != nullptr) {
  172. TF_LITE_REPORT_ERROR(error_reporter_,
  173. "Calling AddBuiltin with the same op more than "
  174. "once is not supported (Op: #%d).",
  175. op);
  176. }
  177. return kTfLiteError;
  178. }
  179. if (registrations_len_ >= tOpCount) {
  180. if (error_reporter_) {
  181. TF_LITE_REPORT_ERROR(error_reporter_,
  182. "Couldn't register builtin op #%d, resolver size "
  183. "is too small (%d).",
  184. op, tOpCount);
  185. }
  186. return kTfLiteError;
  187. }
  188. registrations_[registrations_len_] = registration;
  189. // Strictly speaking, the builtin_code is not necessary for TFLM but filling
  190. // it in regardless.
  191. registrations_[registrations_len_].builtin_code = op;
  192. registrations_len_++;
  193. builtin_codes_[num_buitin_ops_] = op;
  194. builtin_parsers_[num_buitin_ops_] = parser;
  195. num_buitin_ops_++;
  196. return kTfLiteOk;
  197. }
  198. TfLiteRegistration registrations_[tOpCount];
  199. unsigned int registrations_len_ = 0;
  200. // Arrays (and counter) to store the builtin codes and their corresponding
  201. // parse functions as these are registered with the Op Resolver.
  202. BuiltinOperator builtin_codes_[tOpCount];
  203. MicroOpResolver::BuiltinParseFunction builtin_parsers_[tOpCount];
  204. unsigned int num_buitin_ops_ = 0;
  205. ErrorReporter* error_reporter_;
  206. TF_LITE_REMOVE_VIRTUAL_DELETE
  207. };
  208. }; // namespace tflite
  209. #endif // TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_