recognize_commands.h 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
  13. #define TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
  14. #include <cstdint>
  15. #include "tensorflow/lite/c/common.h"
  16. #include "tensorflow/lite/micro/examples/micro_speech/micro_features/micro_model_settings.h"
  17. #include "tensorflow/lite/micro/micro_error_reporter.h"
  18. // Partial implementation of std::dequeue, just providing the functionality
  19. // that's needed to keep a record of previous neural network results over a
  20. // short time period, so they can be averaged together to produce a more
  21. // accurate overall prediction. This doesn't use any dynamic memory allocation
  22. // so it's a better fit for microcontroller applications, but this does mean
  23. // there are hard limits on the number of results it can store.
  24. class PreviousResultsQueue {
  25. public:
  26. PreviousResultsQueue(tflite::ErrorReporter* error_reporter)
  27. : error_reporter_(error_reporter), front_index_(0), size_(0) {}
  28. // Data structure that holds an inference result, and the time when it
  29. // was recorded.
  30. struct Result {
  31. Result() : time_(0), scores() {}
  32. Result(int32_t time, int8_t* input_scores) : time_(time) {
  33. for (int i = 0; i < kCategoryCount; ++i) {
  34. scores[i] = input_scores[i];
  35. }
  36. }
  37. int32_t time_;
  38. int8_t scores[kCategoryCount];
  39. };
  40. int size() { return size_; }
  41. bool empty() { return size_ == 0; }
  42. Result& front() { return results_[front_index_]; }
  43. Result& back() {
  44. int back_index = front_index_ + (size_ - 1);
  45. if (back_index >= kMaxResults) {
  46. back_index -= kMaxResults;
  47. }
  48. return results_[back_index];
  49. }
  50. void push_back(const Result& entry) {
  51. if (size() >= kMaxResults) {
  52. TF_LITE_REPORT_ERROR(
  53. error_reporter_,
  54. "Couldn't push_back latest result, too many already!");
  55. return;
  56. }
  57. size_ += 1;
  58. back() = entry;
  59. }
  60. Result pop_front() {
  61. if (size() <= 0) {
  62. TF_LITE_REPORT_ERROR(error_reporter_,
  63. "Couldn't pop_front result, none present!");
  64. return Result();
  65. }
  66. Result result = front();
  67. front_index_ += 1;
  68. if (front_index_ >= kMaxResults) {
  69. front_index_ = 0;
  70. }
  71. size_ -= 1;
  72. return result;
  73. }
  74. // Most of the functions are duplicates of dequeue containers, but this
  75. // is a helper that makes it easy to iterate through the contents of the
  76. // queue.
  77. Result& from_front(int offset) {
  78. if ((offset < 0) || (offset >= size_)) {
  79. TF_LITE_REPORT_ERROR(error_reporter_,
  80. "Attempt to read beyond the end of the queue!");
  81. offset = size_ - 1;
  82. }
  83. int index = front_index_ + offset;
  84. if (index >= kMaxResults) {
  85. index -= kMaxResults;
  86. }
  87. return results_[index];
  88. }
  89. private:
  90. tflite::ErrorReporter* error_reporter_;
  91. static constexpr int kMaxResults = 50;
  92. Result results_[kMaxResults];
  93. int front_index_;
  94. int size_;
  95. };
  96. // This class is designed to apply a very primitive decoding model on top of the
  97. // instantaneous results from running an audio recognition model on a single
  98. // window of samples. It applies smoothing over time so that noisy individual
  99. // label scores are averaged, increasing the confidence that apparent matches
  100. // are real.
  101. // To use it, you should create a class object with the configuration you
  102. // want, and then feed results from running a TensorFlow model into the
  103. // processing method. The timestamp for each subsequent call should be
  104. // increasing from the previous, since the class is designed to process a stream
  105. // of data over time.
  106. class RecognizeCommands {
  107. public:
  108. // labels should be a list of the strings associated with each one-hot score.
  109. // The window duration controls the smoothing. Longer durations will give a
  110. // higher confidence that the results are correct, but may miss some commands.
  111. // The detection threshold has a similar effect, with high values increasing
  112. // the precision at the cost of recall. The minimum count controls how many
  113. // results need to be in the averaging window before it's seen as a reliable
  114. // average. This prevents erroneous results when the averaging window is
  115. // initially being populated for example. The suppression argument disables
  116. // further recognitions for a set time after one has been triggered, which can
  117. // help reduce spurious recognitions.
  118. explicit RecognizeCommands(tflite::ErrorReporter* error_reporter,
  119. int32_t average_window_duration_ms = 1000,
  120. uint8_t detection_threshold = 200,
  121. int32_t suppression_ms = 1500,
  122. int32_t minimum_count = 3);
  123. // Call this with the results of running a model on sample data.
  124. TfLiteStatus ProcessLatestResults(const TfLiteTensor* latest_results,
  125. const int32_t current_time_ms,
  126. const char** found_command, uint8_t* score,
  127. bool* is_new_command);
  128. private:
  129. // Configuration
  130. tflite::ErrorReporter* error_reporter_;
  131. int32_t average_window_duration_ms_;
  132. uint8_t detection_threshold_;
  133. int32_t suppression_ms_;
  134. int32_t minimum_count_;
  135. // Working variables
  136. PreviousResultsQueue previous_results_;
  137. const char* previous_top_label_;
  138. int32_t previous_top_label_time_;
  139. };
  140. #endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_