flatbuffer_conversions.h

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
#define TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_

// These functions transform codes and data structures that are defined in the
// flatbuffer serialization format into in-memory values that are used by the
// runtime API and interpreter.

#include <cstddef>
#include <new>
#include <type_traits>

#include "tflite/c/common.h"
#include "error_reporter.h"
#include "tflite/schema/schema_generated.h"

namespace tflite {

// Interface class for builtin data allocations.
class BuiltinDataAllocator {
 public:
  virtual void* Allocate(size_t size, size_t alignment_hint) = 0;
  virtual void Deallocate(void* data) = 0;

  // Allocate a structure, but make sure it is a POD structure that doesn't
  // require constructors to run. The reason we do this is that the
  // Interpreter's C extension part will take ownership, so destructors will
  // not be run during deallocation.
  template <typename T>
  T* AllocatePOD() {
    // TODO(b/154346074): Change this to is_trivially_destructible when all
    // platform targets support that properly.
    static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
    void* allocated_memory = this->Allocate(sizeof(T), alignof(T));
    // Value-initialize so the POD members start zeroed rather than
    // indeterminate.
    return new (allocated_memory) T();
  }

  virtual ~BuiltinDataAllocator() {}
};
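
// Example (illustrative sketch, not part of this header): a minimal
// heap-backed implementation of the interface. The name
// `MallocDataAllocator` is hypothetical; any implementation satisfying the
// two pure virtuals works.
//
//   #include <cstdlib>
//
//   struct MallocDataAllocator : public BuiltinDataAllocator {
//     void* Allocate(size_t size, size_t alignment_hint) override {
//       // malloc's alignment suffices for the POD builtin structs, so the
//       // hint is ignored here.
//       (void)alignment_hint;
//       return std::malloc(size);
//     }
//     void Deallocate(void* data) override { std::free(data); }
//   };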

// Parse the appropriate data out of the op.
//
// This handles builtin data explicitly as there are flatbuffer schemas.
// If it returns kTfLiteOk, it passes the data out with `builtin_data`. The
// calling function has to pass in an allocator object, and this allocator
// will be called to reserve space for the output data. If the calling
// function's allocator reserves memory on the heap, then it's the calling
// function's responsibility to free it.
// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
                         ErrorReporter* error_reporter,
                         BuiltinDataAllocator* allocator, void** builtin_data);
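
// Example usage (illustrative sketch; assumes an `Operator* op` from the
// flatbuffer model, an ErrorReporter `reporter`, and an allocator such as
// the `MallocDataAllocator` sketched above; `TfLiteConvParams` comes from
// the builtin op data definitions):
//
//   void* builtin_data = nullptr;
//   MallocDataAllocator allocator;
//   if (ParseOpData(op, BuiltinOperator_CONV_2D, &reporter, &allocator,
//                   &builtin_data) == kTfLiteOk) {
//     auto* params = reinterpret_cast<TfLiteConvParams*>(builtin_data);
//     // ... use params, then release through the same allocator.
//     allocator.Deallocate(builtin_data);
//   }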

// Converts the tensor data type used in the flat buffer to the representation
// used by the runtime.
TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
                               ErrorReporter* error_reporter);
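
// Example (illustrative sketch): mapping a serialized tensor type to the
// runtime enum, again assuming an ErrorReporter `reporter`.
//
//   TfLiteType type;
//   if (ConvertTensorType(TensorType_FLOAT32, &type, &reporter) ==
//       kTfLiteOk) {
//     // type == kTfLiteFloat32
//   }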

// TODO(b/149408647): The (unnecessary) op_type parameter in the functions
// below is to keep the same signature as ParseOpData. This allows for a
// gradual transfer to selective registration of the parse function, but
// should be removed once we are no longer using ParseOpData for the
// OpResolver implementation in micro.
TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator op_type,
                         ErrorReporter* error_reporter,
                         BuiltinDataAllocator* allocator, void** builtin_data);

TfLiteStatus ParseDepthwiseConv2D(const Operator* op, BuiltinOperator op_type,
                                  ErrorReporter* error_reporter,
                                  BuiltinDataAllocator* allocator,
                                  void** builtin_data);

TfLiteStatus ParseDequantize(const Operator* op, BuiltinOperator op_type,
                             ErrorReporter* error_reporter,
                             BuiltinDataAllocator* allocator,
                             void** builtin_data);

TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator op_type,
                                 ErrorReporter* error_reporter,
                                 BuiltinDataAllocator* allocator,
                                 void** builtin_data);

TfLiteStatus ParseQuantize(const Operator* op, BuiltinOperator op_type,
                           ErrorReporter* error_reporter,
                           BuiltinDataAllocator* allocator,
                           void** builtin_data);

TfLiteStatus ParseReshape(const Operator* op, BuiltinOperator op_type,
                          ErrorReporter* error_reporter,
                          BuiltinDataAllocator* allocator, void** builtin_data);

TfLiteStatus ParseSoftmax(const Operator* op, BuiltinOperator op_type,
                          ErrorReporter* error_reporter,
                          BuiltinDataAllocator* allocator, void** builtin_data);

TfLiteStatus ParseSvdf(const Operator* op, BuiltinOperator op_type,
                       ErrorReporter* error_reporter,
                       BuiltinDataAllocator* allocator, void** builtin_data);
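
// Example (illustrative sketch): because the per-op parsers share
// ParseOpData's signature, a resolver can call one directly instead of
// dispatching through ParseOpData, e.g.:
//
//   void* builtin_data = nullptr;
//   TfLiteStatus status = ParseSoftmax(op, BuiltinOperator_SOFTMAX,
//                                      &reporter, &allocator, &builtin_data);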

}  // namespace tflite

#endif  // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_