non_max_suppression.h 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_NON_MAX_SUPPRESSION_H_
  13. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_NON_MAX_SUPPRESSION_H_
  14. #include <algorithm>
  15. #include <cmath>
  16. #include <numeric>
  17. #include <queue>
  18. namespace tflite {
  19. namespace reference_ops {
  // A box described by a pair of diagonally opposite corners, (y1, x1) and
  // (y2, x2). The corners are not required to be ordered: either one may be the
  // minimum corner (the IoU computation in this file normalizes with min/max).
  struct BoxCornerEncoding {
    float y1;
    float x1;
    float y2;
    float x2;
  };
  // Code may view a flat [num_boxes, 4] float buffer as an array of this
  // struct, so its layout must be exactly four contiguous floats, no padding.
  static_assert(sizeof(BoxCornerEncoding) == 4 * sizeof(float),
                "BoxCornerEncoding must be exactly four packed floats");
  // Returns the Intersection-over-Union (IoU) of boxes `i` and `j`.
  //
  // `boxes` is a flat float array holding four values per box in the order
  // [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are any pair of diagonally
  // opposite corners (they need not be sorted). Returns 0 if either box has
  // non-positive area; otherwise intersection_area / union_area.
  //
  // Coordinates are read straight from the float array rather than by
  // reinterpret_cast-ing it to a struct pointer: no struct object lives at that
  // address, so accessing it through a struct type is undefined behavior.
  inline float ComputeIntersectionOverUnion(const float* boxes, const int i,
                                            const int j) {
    // 64-bit offset arithmetic so very large box indices cannot overflow int.
    const float* const box_i = boxes + static_cast<long long>(i) * 4;
    const float* const box_j = boxes + static_cast<long long>(j) * 4;

    // Normalize each box so that min <= max regardless of corner order.
    const float box_i_y_min = std::min(box_i[0], box_i[2]);
    const float box_i_y_max = std::max(box_i[0], box_i[2]);
    const float box_i_x_min = std::min(box_i[1], box_i[3]);
    const float box_i_x_max = std::max(box_i[1], box_i[3]);
    const float box_j_y_min = std::min(box_j[0], box_j[2]);
    const float box_j_y_max = std::max(box_j[0], box_j[2]);
    const float box_j_x_min = std::min(box_j[1], box_j[3]);
    const float box_j_x_max = std::max(box_j[1], box_j[3]);

    const float area_i =
        (box_i_y_max - box_i_y_min) * (box_i_x_max - box_i_x_min);
    const float area_j =
        (box_j_y_max - box_j_y_min) * (box_j_x_max - box_j_x_min);
    // Degenerate boxes overlap nothing; this also keeps the union below
    // strictly positive, so the division is safe.
    if (area_i <= 0.0f || area_j <= 0.0f) return 0.0f;

    const float intersection_ymax = std::min(box_i_y_max, box_j_y_max);
    const float intersection_xmax = std::min(box_i_x_max, box_j_x_max);
    const float intersection_ymin = std::max(box_i_y_min, box_j_y_min);
    const float intersection_xmin = std::max(box_i_x_min, box_j_x_min);
    // Disjoint boxes give a negative extent; clamp it to zero.
    const float intersection_area =
        std::max(intersection_ymax - intersection_ymin, 0.0f) *
        std::max(intersection_xmax - intersection_xmin, 0.0f);
    return intersection_area / (area_i + area_j - intersection_area);
  }
  53. // Implements (Single-Class) Soft NMS (with Gaussian weighting).
  54. // Supports functionality of TensorFlow ops NonMaxSuppressionV4 & V5.
  55. // Reference: "Soft-NMS - Improving Object Detection With One Line of Code"
  56. // [Bodla et al, https://arxiv.org/abs/1704.04503]
  57. // Implementation adapted from the TensorFlow NMS code at
  58. // tensorflow/core/kernels/non_max_suppression_op.cc.
  59. //
  60. // Arguments:
  61. // boxes: box encodings in format [y1, x1, y2, x2], shape: [num_boxes, 4]
  62. // num_boxes: number of candidates
  63. // scores: scores for candidate boxes, in the same order. shape: [num_boxes]
  64. // max_output_size: the maximum number of selections.
  65. // iou_threshold: Intersection-over-Union (IoU) threshold for NMS
  66. // score_threshold: All candidate scores below this value are rejected
  67. // soft_nms_sigma: Soft NMS parameter, used for decaying scores
  68. //
  69. // Outputs:
  70. // selected_indices: all the selected indices. Underlying array must have
  71. // length >= max_output_size. Cannot be null.
  72. // selected_scores: scores of selected indices. Defer from original value for
  73. // Soft NMS. If not null, array must have length >= max_output_size.
  74. // num_selected_indices: Number of selections. Only these many elements are
  75. // set in selected_indices, selected_scores. Cannot be null.
  76. //
  77. // Assumes inputs are valid (for eg, iou_threshold must be >= 0).
  78. inline void NonMaxSuppression(const float* boxes, const int num_boxes,
  79. const float* scores, const int max_output_size,
  80. const float iou_threshold,
  81. const float score_threshold,
  82. const float soft_nms_sigma, int* selected_indices,
  83. float* selected_scores,
  84. int* num_selected_indices) {
  85. struct Candidate {
  86. int index;
  87. float score;
  88. int suppress_begin_index;
  89. };
  90. // Priority queue to hold candidates.
  91. auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
  92. return bs_i.score < bs_j.score;
  93. };
  94. std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)>
  95. candidate_priority_queue(cmp);
  96. // Populate queue with candidates above the score threshold.
  97. for (int i = 0; i < num_boxes; ++i) {
  98. if (scores[i] > score_threshold) {
  99. candidate_priority_queue.emplace(Candidate({i, scores[i], 0}));
  100. }
  101. }
  102. *num_selected_indices = 0;
  103. int num_outputs = std::min(static_cast<int>(candidate_priority_queue.size()),
  104. max_output_size);
  105. if (num_outputs == 0) return;
  106. // NMS loop.
  107. float scale = 0;
  108. if (soft_nms_sigma > 0.0) {
  109. scale = -0.5 / soft_nms_sigma;
  110. }
  111. while (*num_selected_indices < num_outputs &&
  112. !candidate_priority_queue.empty()) {
  113. Candidate next_candidate = candidate_priority_queue.top();
  114. const float original_score = next_candidate.score;
  115. candidate_priority_queue.pop();
  116. // Overlapping boxes are likely to have similar scores, therefore we
  117. // iterate through the previously selected boxes backwards in order to
  118. // see if `next_candidate` should be suppressed. We also enforce a property
  119. // that a candidate can be suppressed by another candidate no more than
  120. // once via `suppress_begin_index` which tracks which previously selected
  121. // boxes have already been compared against next_candidate prior to a given
  122. // iteration. These previous selected boxes are then skipped over in the
  123. // following loop.
  124. bool should_hard_suppress = false;
  125. for (int j = *num_selected_indices - 1;
  126. j >= next_candidate.suppress_begin_index; --j) {
  127. const float iou = ComputeIntersectionOverUnion(
  128. boxes, next_candidate.index, selected_indices[j]);
  129. // First decide whether to perform hard suppression.
  130. if (iou >= iou_threshold) {
  131. should_hard_suppress = true;
  132. break;
  133. }
  134. // Suppress score if NMS sigma > 0.
  135. if (soft_nms_sigma > 0.0) {
  136. next_candidate.score =
  137. next_candidate.score * std::exp(scale * iou * iou);
  138. }
  139. // If score has fallen below score_threshold, it won't be pushed back into
  140. // the queue.
  141. if (next_candidate.score <= score_threshold) break;
  142. }
  143. // If `next_candidate.score` has not dropped below `score_threshold`
  144. // by this point, then we know that we went through all of the previous
  145. // selections and can safely update `suppress_begin_index` to
  146. // `selected.size()`. If on the other hand `next_candidate.score`
  147. // *has* dropped below the score threshold, then since `suppress_weight`
  148. // always returns values in [0, 1], further suppression by items that were
  149. // not covered in the above for loop would not have caused the algorithm
  150. // to select this item. We thus do the same update to
  151. // `suppress_begin_index`, but really, this element will not be added back
  152. // into the priority queue.
  153. next_candidate.suppress_begin_index = *num_selected_indices;
  154. if (!should_hard_suppress) {
  155. if (next_candidate.score == original_score) {
  156. // Suppression has not occurred, so select next_candidate.
  157. selected_indices[*num_selected_indices] = next_candidate.index;
  158. if (selected_scores) {
  159. selected_scores[*num_selected_indices] = next_candidate.score;
  160. }
  161. ++*num_selected_indices;
  162. }
  163. if (next_candidate.score > score_threshold) {
  164. // Soft suppression might have occurred and current score is still
  165. // greater than score_threshold; add next_candidate back onto priority
  166. // queue.
  167. candidate_priority_queue.push(next_candidate);
  168. }
  169. }
  170. }
  171. }
  172. } // namespace reference_ops
  173. } // namespace tflite
  174. #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_NON_MAX_SUPPRESSION_H_