| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
- #include "tensorflow/lite/micro/kernels/kernel_runner.h"
- namespace tflite {
- namespace micro {
namespace {
// Alignment used for every arena allocation made on behalf of a kernel
// (persistent and scratch buffers). 16 bytes is presumably chosen as a
// conservative bound that satisfies any kernel data type and common SIMD
// requirements — NOTE(review): confirm against target-specific needs.
constexpr size_t kBufferAlignment = 16;
}  // namespace
// TODO(b/161841696): Consider moving away from global arena buffers:
// Out-of-class definitions for the static data members declared in
// kernel_runner.h. Prior to C++17, ODR-used static constexpr members require
// a namespace-scope definition like these.
constexpr int KernelRunner::kNumScratchBuffers_;
constexpr int KernelRunner::kKernelRunnerBufferSize_;
uint8_t KernelRunner::kKernelRunnerBuffer_[];
// Builds a standalone harness around a single kernel: the runner owns a
// SimpleMemoryAllocator over the static kKernelRunnerBuffer_ arena and wires
// a minimal TfLiteContext whose callbacks route back to this instance via
// context_.impl_. `tensors` must outlive the runner — it is stored by pointer
// and used as the source of truth for tensor data (see GetTensor /
// GetEvalTensor below).
KernelRunner::KernelRunner(const TfLiteRegistration& registration,
                           TfLiteTensor* tensors, int tensors_size,
                           TfLiteIntArray* inputs, TfLiteIntArray* outputs,
                           void* builtin_data, ErrorReporter* error_reporter)
    : allocator_(SimpleMemoryAllocator::Create(
          error_reporter, kKernelRunnerBuffer_, kKernelRunnerBufferSize_)),
      registration_(registration),
      tensors_(tensors),
      error_reporter_(error_reporter) {
  // Prepare TfLiteContext: impl_ lets the static callbacks below recover the
  // owning KernelRunner from a bare TfLiteContext*.
  context_.impl_ = static_cast<void*>(this);
  context_.ReportError = ReportOpError;
  context_.recommended_num_threads = 1;
  context_.GetTensor = GetTensor;
  context_.GetEvalTensor = GetEvalTensor;
  context_.AllocatePersistentBuffer = AllocatePersistentBuffer;
  context_.RequestScratchBufferInArena = RequestScratchBufferInArena;
  context_.GetScratchBuffer = GetScratchBuffer;

  // Prepare TfLiteNode: the kernel under test sees exactly these in/out
  // index arrays and builtin params.
  node_.inputs = inputs;
  node_.outputs = outputs;
  node_.builtin_data = builtin_data;
}
- TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data) {
- if (registration_.init) {
- node_.user_data = registration_.init(&context_, init_data, /*length=*/0);
- }
- if (registration_.prepare) {
- TF_LITE_ENSURE_STATUS(registration_.prepare(&context_, &node_));
- }
- return kTfLiteOk;
- }
- TfLiteStatus KernelRunner::Invoke() {
- if (registration_.invoke == nullptr) {
- TF_LITE_REPORT_ERROR(error_reporter_,
- "TfLiteRegistration missing invoke function pointer!");
- return kTfLiteError;
- }
- return registration_.invoke(&context_, &node_);
- }
- TfLiteTensor* KernelRunner::GetTensor(const struct TfLiteContext* context,
- int tensor_index) {
- TFLITE_DCHECK(context != nullptr);
- KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
- TFLITE_DCHECK(runner != nullptr);
- return &runner->tensors_[tensor_index];
- }
- TfLiteEvalTensor* KernelRunner::GetEvalTensor(
- const struct TfLiteContext* context, int tensor_index) {
- TFLITE_DCHECK(context != nullptr);
- KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
- TFLITE_DCHECK(runner != nullptr);
- TfLiteEvalTensor* eval_tensor =
- reinterpret_cast<TfLiteEvalTensor*>(runner->allocator_->AllocateTemp(
- sizeof(TfLiteEvalTensor), alignof(TfLiteEvalTensor)));
- TFLITE_DCHECK(eval_tensor != nullptr);
- // In unit tests, the TfLiteTensor pointer contains the source of truth for
- // buffers and values:
- eval_tensor->data = runner->tensors_[tensor_index].data;
- eval_tensor->dims = runner->tensors_[tensor_index].dims;
- eval_tensor->type = runner->tensors_[tensor_index].type;
- return eval_tensor;
- }
- void* KernelRunner::AllocatePersistentBuffer(TfLiteContext* context,
- size_t bytes) {
- TFLITE_DCHECK(context != nullptr);
- KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
- TFLITE_DCHECK(runner != nullptr);
- return runner->allocator_->AllocateFromTail(bytes, kBufferAlignment);
- }
- TfLiteStatus KernelRunner::RequestScratchBufferInArena(TfLiteContext* context,
- size_t bytes,
- int* buffer_index) {
- TFLITE_DCHECK(context != nullptr);
- TFLITE_DCHECK(buffer_index != nullptr);
- KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
- TFLITE_DCHECK(runner != nullptr);
- if (runner->scratch_buffer_count_ == kNumScratchBuffers_) {
- TF_LITE_REPORT_ERROR(
- runner->error_reporter_,
- "Exceeded the maximum number of scratch tensors allowed (%d).",
- kNumScratchBuffers_);
- return kTfLiteError;
- }
- // For tests, we allocate scratch buffers from the tail and keep them around
- // for the lifetime of model. This means that the arena size in the tests will
- // be more than what we would have if the scratch buffers could share memory.
- runner->scratch_buffers_[runner->scratch_buffer_count_] =
- runner->allocator_->AllocateFromTail(bytes, kBufferAlignment);
- TFLITE_DCHECK(runner->scratch_buffers_[runner->scratch_buffer_count_] !=
- nullptr);
- *buffer_index = runner->scratch_buffer_count_++;
- return kTfLiteOk;
- }
- void* KernelRunner::GetScratchBuffer(TfLiteContext* context, int buffer_index) {
- TFLITE_DCHECK(context != nullptr);
- KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
- TFLITE_DCHECK(runner != nullptr);
- TFLITE_DCHECK(runner->scratch_buffer_count_ <= kNumScratchBuffers_);
- if (buffer_index >= runner->scratch_buffer_count_) {
- return nullptr;
- }
- return runner->scratch_buffers_[buffer_index];
- }
// TfLiteContext::ReportError trampoline: forwards a kernel's printf-style
// error to the runner's ErrorReporter.
void KernelRunner::ReportOpError(struct TfLiteContext* context,
                                 const char* format, ...) {
  TFLITE_DCHECK(context != nullptr);
  KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
  TFLITE_DCHECK(runner != nullptr);

  va_list args;
  va_start(args, format);
  // NOTE(review): passing a va_list through TF_LITE_REPORT_ERROR relies on
  // overload resolution selecting ErrorReporter::Report(const char*, va_list)
  // rather than the variadic Report(const char*, ...); confirm the macro
  // expands to a direct Report(__VA_ARGS__) call so this holds.
  TF_LITE_REPORT_ERROR(runner->error_reporter_, format, args);
  va_end(args);
}
- } // namespace micro
- } // namespace tflite
|