| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204 |
- /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
- #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
- #define TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
- #include <limits>
- #include <vector>
- #include "tensorflow/lite/kernels/internal/compatibility.h"
- #include "tensorflow/lite/kernels/internal/types.h"
- namespace tflite {
- namespace strided_slice {
- // Use until std::clamp() is available from C++17.
- inline int Clamp(const int v, const int lo, const int hi) {
- TFLITE_DCHECK(!(hi < lo));
- if (hi < v) return hi;
- if (v < lo) return lo;
- return v;
- }
- inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
- int dim_count) {
- // Add indices and mask bits to fully include extra dimensions
- TFLITE_CHECK_LE(dim_count, 5);
- TFLITE_CHECK_GE(dim_count, p->start_indices_count);
- TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
- TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
- const int pad_count = dim_count - p->start_indices_count;
- // Pad indices at start, so move arrays by pad_count.
- for (int i = p->start_indices_count - 1; i >= 0; --i) {
- p->strides[i + pad_count] = p->strides[i];
- p->start_indices[i + pad_count] = p->start_indices[i];
- p->stop_indices[i + pad_count] = p->stop_indices[i];
- }
- for (int i = 0; i < pad_count; ++i) {
- p->start_indices[i] = 0;
- p->stop_indices[i] = 1;
- p->strides[i] = 1;
- }
- // Pad masks with 0s or 1s as required.
- p->shrink_axis_mask <<= pad_count;
- p->ellipsis_mask <<= pad_count;
- p->new_axis_mask <<= pad_count;
- p->begin_mask <<= pad_count;
- p->end_mask <<= pad_count;
- p->begin_mask |= (1 << pad_count) - 1;
- p->end_mask |= (1 << pad_count) - 1;
- p->start_indices_count = dim_count;
- p->stop_indices_count = dim_count;
- p->strides_count = dim_count;
- }
- // Return the index for the first element along that axis. This index will be a
- // positive integer between [0, axis_size - 1] that can be used to index
- // directly into the data.
- inline int StartForAxis(const tflite::StridedSliceParams& params,
- const RuntimeShape& input_shape, int axis) {
- const auto begin_mask = params.begin_mask;
- const auto* start_indices = params.start_indices;
- const auto* strides = params.strides;
- const int axis_size = input_shape.Dims(axis);
- if (axis_size == 0) {
- return 0;
- }
- // Begin with the specified index.
- int start = start_indices[axis];
- // begin_mask override
- if (begin_mask & 1 << axis) {
- if (strides[axis] > 0) {
- // Forward iteration - use the first element. These values will get
- // clamped below (Note: We could have set them to 0 and axis_size-1, but
- // use lowest() and max() to maintain symmetry with StopForAxis())
- start = std::numeric_limits<int>::lowest();
- } else {
- // Backward iteration - use the last element.
- start = std::numeric_limits<int>::max();
- }
- }
- // Handle negative indices
- if (start < 0) {
- start += axis_size;
- }
- // Clamping
- start = Clamp(start, 0, axis_size - 1);
- return start;
- }
- // Return the "real" index for the end of iteration along that axis. This is an
- // "end" in the traditional C sense, in that it points to one past the last
- // element. ie. So if you were iterating through all elements of a 1D array of
- // size 4, this function would return 4 as the stop, because it is one past the
- // "real" indices of 0, 1, 2 & 3.
- inline int StopForAxis(const tflite::StridedSliceParams& params,
- const RuntimeShape& input_shape, int axis,
- int start_for_axis) {
- const auto end_mask = params.end_mask;
- const auto shrink_axis_mask = params.shrink_axis_mask;
- const auto* stop_indices = params.stop_indices;
- const auto* strides = params.strides;
- const int axis_size = input_shape.Dims(axis);
- if (axis_size == 0) {
- return 0;
- }
- // Begin with the specified index
- const bool shrink_axis = shrink_axis_mask & (1 << axis);
- int stop = stop_indices[axis];
- // When shrinking an axis, the end position does not matter (and can be
- // incorrect when negative indexing is used, see Issue #19260). Always use
- // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
- // already been adjusted for negative indices.
- if (shrink_axis) {
- stop = start_for_axis + 1;
- }
- // end_mask override
- if (end_mask & (1 << axis)) {
- if (strides[axis] > 0) {
- // Forward iteration - use the last element. These values will get
- // clamped below
- stop = std::numeric_limits<int>::max();
- } else {
- // Backward iteration - use the first element.
- stop = std::numeric_limits<int>::lowest();
- }
- }
- // Handle negative indices
- if (stop < 0) {
- stop += axis_size;
- }
- // Clamping
- // Because the end index points one past the last element, we need slightly
- // different clamping ranges depending on the direction.
- if (strides[axis] > 0) {
- // Forward iteration
- stop = Clamp(stop, 0, axis_size);
- } else {
- // Backward iteration
- stop = Clamp(stop, -1, axis_size - 1);
- }
- return stop;
- }
// Returns true once `index` has reached or moved past `stop` in the direction
// given by `stride`, i.e. iteration along this axis is finished. Positive
// strides finish at index >= stop; zero or negative strides finish at
// index <= stop.
inline bool LoopCondition(int index, int stop, int stride) {
  if (stride > 0) {
    return !(index < stop);
  }
  return !(index > stop);
}
- inline tflite::StridedSliceParams BuildStridedSliceParams(
- int begin_mask, int end_mask, int shrink_axis_mask,
- const std::vector<int>& start_indices, const std::vector<int>& stop_indices,
- const std::vector<int>& strides) {
- tflite::StridedSliceParams op_params;
- const int dims_count = start_indices.size();
- op_params.start_indices_count = dims_count;
- op_params.stop_indices_count = dims_count;
- op_params.strides_count = dims_count;
- for (int i = 0; i < dims_count; ++i) {
- op_params.start_indices[i] = start_indices[i];
- op_params.stop_indices[i] = stop_indices[i];
- op_params.strides[i] = strides[i];
- }
- op_params.begin_mask = begin_mask;
- op_params.ellipsis_mask = 0;
- op_params.end_mask = end_mask;
- op_params.new_axis_mask = 0;
- op_params.shrink_axis_mask = shrink_axis_mask;
- return op_params;
- }
- } // namespace strided_slice
- } // namespace tflite
- #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
|