arm_lstm_unidirectional_s8.c 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. /*
  2. * SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates <open-source-office@arm.com>
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_lstm_unidirectional_s8.c
  21. * Description: S8 LSTM function with S16 gate output
  22. *
  23. * $Date: 08 February 2024
  24. * $Revision: V.1.1.0
  25. *
  26. * Target Processor: Cortex-M processors
  27. *
  28. * -------------------------------------------------------------------- */
  29. #include "arm_nnfunctions.h"
  30. #include "arm_nnsupportfunctions.h"
  31. /**
  32. * @ingroup Public
  33. */
  34. /**
  35. * @addtogroup LSTM
  36. * @{
  37. */
  38. /*
  39. * S8 LSTM function for TensorFlow Lite with S16 gate output
  40. *
  41. * Refer to header file for details.
  42. *
  43. */
  44. arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input,
  45. int8_t *output,
  46. const cmsis_nn_lstm_params *params,
  47. cmsis_nn_lstm_context *buffers)
  48. {
  49. int8_t *hidden_in = NULL;
  50. memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t));
  51. if (params->time_major)
  52. {
  53. // First dimension is time, input/output for each time step is stored continously in memory
  54. for (int t = 0; t < params->time_steps; t++)
  55. {
  56. const int8_t *data_in = input + (t * params->batch_size * params->input_size);
  57. int8_t *hidden_out = output + (t * params->batch_size * params->hidden_size);
  58. arm_cmsis_nn_status status = arm_nn_lstm_step_s8(data_in, hidden_in, hidden_out, params, buffers, 1);
  59. if (status != ARM_CMSIS_NN_SUCCESS)
  60. {
  61. return status;
  62. }
  63. // Output is used as recurrent input/hidden state for the next timestep.
  64. hidden_in = &hidden_out[0];
  65. }
  66. }
  67. else
  68. {
  69. // First dimension is time, add batch_offset to jump in memory for each batch
  70. for (int t = 0; t < params->time_steps; t++)
  71. {
  72. const int8_t *data_in = input + (t * params->input_size);
  73. int8_t *hidden_out = output + (t * params->hidden_size);
  74. arm_cmsis_nn_status status =
  75. arm_nn_lstm_step_s8(data_in, hidden_in, hidden_out, params, buffers, params->time_steps);
  76. if (status != ARM_CMSIS_NN_SUCCESS)
  77. {
  78. return status;
  79. }
  80. // Output is used as recurrent input/hidden state for the next timestep.
  81. hidden_in = &hidden_out[0];
  82. }
  83. }
  84. return ARM_CMSIS_NN_SUCCESS;
  85. }
  86. /**
  87. * @} end of LSTM group
  88. */