/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
#include <algorithm>

#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {

// Consolidates dimensions in broadcast inputs, checks for five-fold pattern.
//
// For example, if sequence of dimensions of one input is
// ..., 1, 3, 1, 7, 9, 5,... and the other is ..., 2, 3, 1, 7, 1, 1, ...,
// we can consolidate these as
// ..., 1, 3*7, 9*5, ... and ..., 2, 3*7, 1, ....
//
// The category is updated in the less-frequent case of shapes that are
// not suited to a fivefold-loop broadcast.
//
// Falls back to generic pattern when it does not know how to process properly.
//
// Returns true iff there is some sort of broadcast, which includes five-fold
// patterns and falling back to generic broadcast.
  31. inline bool ProcessBroadcastShapes(const RuntimeShape& shape0,
  32. const RuntimeShape& shape1,
  33. tflite::ArithmeticParams* params) {
  34. const int dims_count =
  35. std::max(shape0.DimensionsCount(), shape1.DimensionsCount());
  36. params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
  37. RuntimeShape scalar_shape(dims_count, 1);
  38. auto extended_shape0 = RuntimeShape::ExtendedShape(dims_count, shape0);
  39. auto extended_shape1 = RuntimeShape::ExtendedShape(dims_count, shape1);
  40. // Check for "exact" match, implicitly accepting any scalar shapes.
  41. if (extended_shape0 == extended_shape1) {
  42. params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
  43. return false;
  44. }
  45. for (int i = dims_count - 1; i >= 0; --i) {
  46. if (extended_shape0.Dims(i) == extended_shape1.Dims(i)) {
  47. continue;
  48. } else if (extended_shape0.Dims(i) == 1) {
  49. params->broadcast_category =
  50. BroadcastableOpCategory::kFirstInputBroadcastsFast;
  51. break;
  52. } else if (extended_shape1.Dims(i) == 1) {
  53. params->broadcast_category =
  54. BroadcastableOpCategory::kSecondInputBroadcastsFast;
  55. break;
  56. } else {
  57. // This case is erroneous: there is a dimension that does not match and
  58. // is not a broadcast from one shape to the other.
  59. params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
  60. return true;
  61. }
  62. }
  63. if (params->broadcast_category !=
  64. BroadcastableOpCategory::kFirstInputBroadcastsFast &&
  65. params->broadcast_category !=
  66. BroadcastableOpCategory::kSecondInputBroadcastsFast) {
  67. // This is unreachable because at least one else clause in the above loop
  68. // must be reached.
  69. TFLITE_DCHECK(false);
  70. params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
  71. return false;
  72. }
  73. // From this point it is assumed contractually that corresponding dimensions
  74. // in shape0 and shape1 are either (a) equal or (b) one or other equals 1.
  75. const bool swap_inputs = params->broadcast_category ==
  76. BroadcastableOpCategory::kSecondInputBroadcastsFast;
  77. const RuntimeShape* shape_a =
  78. swap_inputs ? &extended_shape1 : &extended_shape0;
  79. const RuntimeShape* shape_b =
  80. swap_inputs ? &extended_shape0 : &extended_shape1;
  81. int i = dims_count - 1;
  82. params->broadcast_shape[0] = 1;
  83. params->broadcast_shape[1] = 1;
  84. params->broadcast_shape[2] = 1;
  85. params->broadcast_shape[3] = 1;
  86. params->broadcast_shape[4] = 1;
  87. // y_0 is greedy: include dims if both or neither equal 1: in other words,
  88. // test for equality rather than (shape_a->Dims(i) != 1).
  89. while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
  90. params->broadcast_shape[4] *= shape_b->Dims(i);
  91. --i;
  92. }
  93. // Here either input_a or input_b has dim of 1 (if i >= 0). If it is input_b
  94. // that has the unit dimension, the next two loops are not entered.
  95. while (i >= 0 && shape_a->Dims(i) == 1) {
  96. params->broadcast_shape[3] *= shape_b->Dims(i);
  97. --i;
  98. }
  99. while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
  100. params->broadcast_shape[2] *= shape_a->Dims(i);
  101. --i;
  102. }
  103. // Here either input_a or input_b has dim of 1 (if i >= 0).
  104. while (i >= 0 && shape_b->Dims(i) == 1) {
  105. params->broadcast_shape[1] *= shape_a->Dims(i);
  106. --i;
  107. }
  108. while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
  109. params->broadcast_shape[0] *= shape_b->Dims(i);
  110. --i;
  111. }
  112. // Rarer case is when the broadcast dimensions cannot be handled by a fivefold
  113. // loop.
  114. if (i >= 0) {
  115. params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
  116. }
  117. return true;
  118. }
}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_