| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335 |
- /*
- * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
- *
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the License); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- /* ----------------------------------------------------------------------
- * Project: CMSIS NN Library
- * Title: arm_max_pool_s8_opt.c
- * Description: Pooling function implementations
- *
- * $Date: February 27, 2020
- * $Revision: V.1.0.1
- *
- * Target Processor: Cortex-M
- *
- * -------------------------------------------------------------------- */
- #include "arm_math.h"
- #include "arm_nnfunctions.h"
- #if defined(ARM_MATH_DSP)
- static void compare_and_replace_if_larger_q7(q7_t *base,
- const q7_t *target,
- const uint16_t length)
- {
- #if defined(ARM_MATH_MVEI)
- int loop_count = length / 16;
- while (loop_count)
- {
- const int8x16_t op_1 = vldrbq_s8(base);
- const int8x16_t op_2 = vldrbq_s8(target);
- const int8x16_t max = vmaxq_s8(op_1, op_2);
- vstrbq_s8(base, max);
- base += 16;
- target += 16;
- loop_count--;
- }
- if (((length & 0xF) / 8) > 0)
- {
- const int16x8_t op_1 = vldrbq_s16(base);
- const int16x8_t op_2 = vldrbq_s16(target);
- const int16x8_t max = vmaxq_s16(op_1, op_2);
- vstrbq_s16(base, max);
- base += 8;
- target += 8;
- }
- for (int i = 0; i < (length & 7); i++)
- {
- if (target[i] > base[i])
- {
- base[i] = target[i];
- }
- }
- #else
- q7_t *dst = base;
- const q7_t *src = target;
- union arm_nnword ref_max;
- union arm_nnword comp_max;
- int32_t cnt = length >> 2;
- while (cnt > 0l)
- {
- ref_max.word = arm_nn_read_q7x4(dst);
- comp_max.word = arm_nn_read_q7x4_ia(&src);
- if (comp_max.bytes[0] > ref_max.bytes[0])
- {
- ref_max.bytes[0] = comp_max.bytes[0];
- }
- if (comp_max.bytes[1] > ref_max.bytes[1])
- {
- ref_max.bytes[1] = comp_max.bytes[1];
- }
- if (comp_max.bytes[2] > ref_max.bytes[2])
- {
- ref_max.bytes[2] = comp_max.bytes[2];
- }
- if (comp_max.bytes[3] > ref_max.bytes[3])
- {
- ref_max.bytes[3] = comp_max.bytes[3];
- }
- write_q7x4_ia(&dst, ref_max.word);
- cnt--;
- }
- cnt = length & 0x3;
- while (cnt > 0l)
- {
- if (*src > *dst)
- {
- *dst = *src;
- }
- dst++;
- src++;
- cnt--;
- }
- #endif
- }
- static void clamp_output(q7_t *source, const uint16_t length, const int32_t act_min, const int32_t act_max)
- {
- #if defined(ARM_MATH_MVEI)
- int cnt = length / 16;
- while (cnt > 0)
- {
- const int8x16_t src = vldrbq_s8(source);
- int8x16_t res = vmaxq_s8(src, vdupq_n_s8((int8_t)act_min));
- res = vminq_s8(src, vdupq_n_s8((int8_t)act_max));
- vstrbq_s8(source, res);
- source += 16;
- cnt--;
- }
- if (((length & 0xF) / 8) > 0)
- {
- const int16x8_t src = vldrbq_s16(source);
- int16x8_t res = vmaxq_s16(src, vdupq_n_s16((int16_t)act_min));
- res = vminq_s16(src, vdupq_n_s16((int16_t)act_max));
- vstrbq_s16(source, res);
- source += 8;
- }
- cnt = length & 7;
- while (cnt > 0)
- {
- int32_t comp = *source;
- comp = MAX(comp, act_min);
- comp = MIN(comp, act_max);
- *source++ = (int8_t)comp;
- cnt--;
- }
- #else
- union arm_nnword in;
- int32_t cnt = length >> 2;
- while (cnt > 0l)
- {
- in.word = arm_nn_read_q7x4(source);
- in.bytes[0] = MAX(in.bytes[0], act_min);
- in.bytes[0] = MIN(in.bytes[0], act_max);
- in.bytes[1] = MAX(in.bytes[1], act_min);
- in.bytes[1] = MIN(in.bytes[1], act_max);
- in.bytes[2] = MAX(in.bytes[2], act_min);
- in.bytes[2] = MIN(in.bytes[2], act_max);
- in.bytes[3] = MAX(in.bytes[3], act_min);
- in.bytes[3] = MIN(in.bytes[3], act_max);
- write_q7x4_ia(&source, in.word);
- cnt--;
- }
- cnt = length & 0x3;
- while (cnt > 0l)
- {
- int32_t comp = *source;
- comp = MAX(comp, act_min);
- comp = MIN(comp, act_max);
- *source++ = (int8_t)comp;
- cnt--;
- }
- #endif
- }
- #endif
- /**
- * @ingroup groupNN
- */
- /**
- * @addtogroup Pooling
- * @{
- */
- /*
- * Optimized s8 max pooling function
- *
- * Refer to header file for details.
- *
- */
- arm_status arm_max_pool_s8_opt(const uint16_t input_y,
- const uint16_t input_x,
- const uint16_t output_y,
- const uint16_t output_x,
- const uint16_t stride_y,
- const uint16_t stride_x,
- const uint16_t kernel_y,
- const uint16_t kernel_x,
- const uint16_t pad_y,
- const uint16_t pad_x,
- const int8_t act_min,
- const int8_t act_max,
- const uint16_t depth,
- int8_t *src,
- int16_t *tmp_buffer,
- int8_t *dst)
- {
- #if defined(ARM_MATH_DSP)
- /* Run the following code for Cortex-M4 and Cortex-M7 */
- (void)tmp_buffer;
- int32_t i_x, i_y;
- /* first does the pooling along x axis */
- for (i_y = 0; i_y < input_y; i_y++)
- {
- for (i_x = 0; i_x < output_x; i_x++)
- {
- /* for each output sample */
- q7_t *target = src + (i_y * input_x + i_x) * depth;
- q7_t *win_start;
- q7_t *win_stop;
- const int32_t x_origin = i_x * stride_x - pad_x;
- if (x_origin < 0)
- {
- win_start = target;
- }
- else
- {
- win_start = src + (i_y * input_x + x_origin) * depth;
- }
- if (x_origin + kernel_x >= input_x)
- {
- win_stop = src + (i_y * input_x + input_x) * depth;
- }
- else
- {
- win_stop = src + (i_y * input_x + x_origin + kernel_x) * depth;
- }
- /* first step is to copy over initial data(along channel) along the channel in x direction */
- memmove(target, win_start, depth);
- /* Move over to next element along x axis and compare with the base(target) */
- win_start += depth;
- for (; win_start < win_stop; win_start += depth)
- {
- compare_and_replace_if_larger_q7(target, win_start, depth);
- }
- }
- }
- /* then does the pooling along y axis */
- for (i_y = 0; i_y < output_y; i_y++)
- {
- /* for each output row */
- q7_t *target = dst + i_y * output_x * depth;
- q7_t *row_start;
- q7_t *row_end;
- const int32_t y_origin = i_y * stride_y - pad_y;
- /* setting the starting row */
- if (y_origin < 0)
- {
- row_start = src;
- }
- else
- {
- row_start = src + y_origin * input_x * depth;
- }
- /* setting the stopping row */
- if (y_origin + kernel_y >= input_y)
- {
- row_end = src + input_y * input_x * depth;
- }
- else
- {
- row_end = src + (y_origin + kernel_y) * input_x * depth;
- }
- /* copy over the complete first row. */
- memmove(target, row_start, output_x * depth);
- /* move over to next row and compare with the base row (target)*/
- row_start += depth * input_x;
- for (; row_start < row_end; row_start += input_x * depth)
- {
- compare_and_replace_if_larger_q7(target, row_start, output_x * depth);
- }
- }
- clamp_output(dst, output_x * output_y * depth, act_min, act_max);
- #else
- /* Pure C implementation */
- arm_max_pool_s8(input_y,
- input_x,
- output_y,
- output_x,
- stride_y,
- stride_x,
- kernel_y,
- kernel_x,
- pad_y,
- pad_x,
- act_min,
- act_max,
- depth,
- src,
- tmp_buffer,
- dst);
- #endif
- return ARM_MATH_SUCCESS;
- }
- /**
- * @} end of Pooling group
- */
|