wasi_nn.h 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. /*
  2. * Copyright (C) 2019 Intel Corporation. All rights reserved.
  3. * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. */
  5. /**
  6. * Following definition from:
  7. * [Oct 25th, 2022]
  8. * https://github.com/WebAssembly/wasi-nn/blob/0f77c48ec195748990ff67928a4b3eef5f16c2de/wasi-nn.wit.md
  9. */
  10. #ifndef WASI_NN_H
  11. #define WASI_NN_H
  12. #include <stdint.h>
  13. #include "wasi_nn_types.h"
  14. #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
  15. #define WASI_NN_IMPORT(name) \
  16. __attribute__((import_module("wasi_ephemeral_nn"), import_name(name)))
  17. #else
  18. #define WASI_NN_IMPORT(name) \
  19. __attribute__((import_module("wasi_nn"), import_name(name)))
  20. #warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
  21. #endif
  22. /**
  23. * @brief Load an opaque sequence of bytes to use for inference.
  24. *
  25. * @param builder Model builder.
  26. * @param builder_len The size of model builder.
  27. * @param encoding Model encoding.
  28. * @param target Execution target.
  29. * @param g Graph.
  30. * @return wasi_nn_error Execution status.
  31. */
  32. #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
  33. WASI_NN_ERROR_TYPE
  34. WASI_NN_NAME(load)
  35. (WASI_NN_NAME(graph_builder) * builder, uint32_t builder_len,
  36. WASI_NN_NAME(graph_encoding) encoding, WASI_NN_NAME(execution_target) target,
  37. WASI_NN_NAME(graph) * g) WASI_NN_IMPORT("load");
  38. #else
  39. WASI_NN_ERROR_TYPE
  40. WASI_NN_NAME(load)
  41. (WASI_NN_NAME(graph_builder_array) * builder,
  42. WASI_NN_NAME(graph_encoding) encoding, WASI_NN_NAME(execution_target) target,
  43. WASI_NN_NAME(graph) * g) WASI_NN_IMPORT("load");
  44. #endif
  45. WASI_NN_ERROR_TYPE
  46. WASI_NN_NAME(load_by_name)
  47. (const char *name, uint32_t name_len, WASI_NN_NAME(graph) * g)
  48. WASI_NN_IMPORT("load_by_name");
  49. /**
  50. * INFERENCE
  51. *
  52. */
  53. /**
  54. * @brief Create an execution instance of a loaded graph.
  55. *
  56. * @param g Graph.
  57. * @param ctx Execution context.
  58. * @return wasi_nn_error Execution status.
  59. */
  60. WASI_NN_ERROR_TYPE
  61. WASI_NN_NAME(init_execution_context)
  62. (WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) * ctx)
  63. WASI_NN_IMPORT("init_execution_context");
  64. /**
  65. * @brief Define the inputs to use for inference.
  66. *
  67. * @param ctx Execution context.
  68. * @param index Input tensor index.
  69. * @param tensor Input tensor.
  70. * @return wasi_nn_error Execution status.
  71. */
  72. WASI_NN_ERROR_TYPE
  73. WASI_NN_NAME(set_input)
  74. (WASI_NN_NAME(graph_execution_context) ctx, uint32_t index,
  75. WASI_NN_NAME(tensor) * tensor) WASI_NN_IMPORT("set_input");
  76. /**
  77. * @brief Compute the inference on the given inputs.
  78. *
  79. * @param ctx Execution context.
  80. * @return wasi_nn_error Execution status.
  81. */
  82. WASI_NN_ERROR_TYPE
  83. WASI_NN_NAME(compute)
  84. (WASI_NN_NAME(graph_execution_context) ctx) WASI_NN_IMPORT("compute");
  85. /**
  86. * @brief Extract the outputs after inference.
  87. *
  88. * @param ctx Execution context.
  89. * @param index Output tensor index.
  90. * @param output_tensor Buffer where output tensor with index `index` is
  91. * copied.
  92. * @param output_tensor_size Pointer to `output_tensor` maximum size.
  93. * After the function call it is updated with the
  94. * copied number of bytes.
  95. * @return wasi_nn_error Execution status.
  96. */
  97. #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
  98. WASI_NN_ERROR_TYPE
  99. WASI_NN_NAME(get_output)
  100. (WASI_NN_NAME(graph_execution_context) ctx, uint32_t index,
  101. uint8_t *output_tensor, uint32_t output_tensor_max_size,
  102. uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output");
  103. #else
  104. WASI_NN_ERROR_TYPE
  105. WASI_NN_NAME(get_output)
  106. (graph_execution_context ctx, uint32_t index, uint8_t *output_tensor,
  107. uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output");
  108. #endif
  109. #endif