interpreter.h 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. /* Copyright 2019-2020 Canaan Inc.
  2. *
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. *
  7. * http://www.apache.org/licenses/LICENSE-2.0
  8. *
  9. * Unless required by applicable law or agreed to in writing, software
  10. * distributed under the License is distributed on an "AS IS" BASIS,
  11. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. * See the License for the specific language governing permissions and
  13. * limitations under the License.
  14. */
  15. #pragma once
  16. #include "model.h"
  17. #include <chrono>
  18. #include <memory>
  19. #include <optional>
  20. #include <xtl/xspan.hpp>
  21. namespace nncase
  22. {
  23. namespace runtime
  24. {
  25. class interpreter_base;
  26. typedef void (*run_callback_t)(void *userdata);
  27. typedef void (*error_callback_t)(const char *err, void *userdata);
  28. typedef void (*node_profile_callback_t)(runtime_opcode op, std::chrono::nanoseconds duration, void *userdata);
  29. typedef void (interpreter_base::*interpreter_step_t)();
  30. class interpreter_base
  31. {
  32. public:
  33. using clock_t = std::chrono::system_clock;
  34. bool try_load_model(const uint8_t *buffer);
  35. uint32_t model_size(const uint8_t *buffer);
  36. size_t inputs_size() const noexcept { return model_header_->inputs; }
  37. size_t outputs_size() const noexcept { return model_header_->outputs; }
  38. size_t nodes_size() const noexcept { return model_header_->nodes; }
  39. const runtime_shape_t &input_shape_at(size_t index) const noexcept { return input_shapes_.at(index); }
  40. const memory_range &input_at(size_t index) const noexcept { return inputs_[index]; }
  41. const memory_range &output_at(size_t index) const noexcept { return outputs_[index]; }
  42. template <class T>
  43. xtl::span<T> memory_at(const memory_range &range) const noexcept
  44. {
  45. auto span = memory_at(range);
  46. return { reinterpret_cast<T *>(span.data()), span.size() / sizeof(T) };
  47. }
  48. std::chrono::nanoseconds total_duration() const noexcept { return total_duration_; }
  49. void run(run_callback_t callback, error_callback_t on_error, node_profile_callback_t node_profile, void *userdata);
  50. protected:
  51. virtual bool initialize();
  52. virtual xtl::span<uint8_t> memory_at(const memory_range &range) const noexcept;
  53. virtual clock_t::time_point get_now() const noexcept;
  54. private:
  55. void step();
  56. private:
  57. const model_header *model_header_;
  58. std::unique_ptr<uint8_t[]> main_mem_;
  59. xtl::span<const memory_range> inputs_;
  60. xtl::span<const memory_range> outputs_;
  61. xtl::span<const runtime_shape_t> input_shapes_;
  62. xtl::span<const node_header> node_headers_;
  63. xtl::span<const uint8_t> constants_;
  64. const uint8_t *node_body_start_;
  65. error_callback_t on_error_;
  66. run_callback_t run_callback_;
  67. node_profile_callback_t node_profile_;
  68. void *userdata_;
  69. size_t cnt_node_;
  70. const uint8_t *cnt_node_body_;
  71. std::chrono::nanoseconds total_duration_;
  72. std::optional<clock_t::time_point> last_time_;
  73. runtime_opcode last_op_;
  74. };
  75. }
  76. }