/*
 * Copyright (C) 2019 Intel Corporation. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 */

/**
 * Following definition from:
 * [Oct 25th, 2022]
 * https://github.com/WebAssembly/wasi-nn/blob/0f77c48ec195748990ff67928a4b3eef5f16c2de/wasi-nn.wit.md
 */

#ifndef WASI_NN_H
#define WASI_NN_H

#include <stdint.h>

#include "wasi_nn_types.h"
/**
 * @brief Load an opaque sequence of bytes to use for inference.
 *
 * Imported from the "wasi_nn" host module; the host resolves the actual
 * implementation at instantiation time.
 *
 * @param builder Model builder: array of byte buffers holding the model data.
 * @param encoding Model encoding (e.g. which framework serialized the model).
 * @param target Execution target to run the model on.
 * @param g Output parameter: receives the handle of the loaded graph on
 *          success.
 * @return wasi_nn_error Execution status.
 */
wasi_nn_error
load(graph_builder_array *builder, graph_encoding encoding,
     execution_target target, graph *g)
    __attribute__((import_module("wasi_nn")));
/**
 * @brief Load a graph by name.
 *
 * NOTE(review): per the wasi-nn spec linked above, `name` presumably refers
 * to a model pre-registered with the host runtime — confirm against the host
 * implementation.
 *
 * @param name Null-terminated name identifying the model to load.
 * @param g Output parameter: receives the handle of the loaded graph on
 *          success.
 * @return wasi_nn_error Execution status.
 */
wasi_nn_error
load_by_name(const char *name, graph *g)
    __attribute__((import_module("wasi_nn")));
/**
 * INFERENCE
 */

/**
 * @brief Create an execution instance of a loaded graph.
 *
 * @param g Graph handle previously obtained from load()/load_by_name().
 * @param ctx Output parameter: receives the handle of the new execution
 *            context on success.
 * @return wasi_nn_error Execution status.
 */
wasi_nn_error
init_execution_context(graph g, graph_execution_context *ctx)
    __attribute__((import_module("wasi_nn")));
/**
 * @brief Define the inputs to use for inference.
 *
 * @param ctx Execution context.
 * @param index Input tensor index within the graph's input set.
 * @param tensor Input tensor to bind at `index`.
 * @return wasi_nn_error Execution status.
 */
wasi_nn_error
set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
    __attribute__((import_module("wasi_nn")));
/**
 * @brief Compute the inference on the given inputs.
 *
 * Inputs must have been bound beforehand via set_input(); results are
 * retrieved afterwards with get_output().
 *
 * @param ctx Execution context.
 * @return wasi_nn_error Execution status.
 */
wasi_nn_error
compute(graph_execution_context ctx) __attribute__((import_module("wasi_nn")));
/**
 * @brief Extract the outputs after inference.
 *
 * @param ctx Execution context.
 * @param index Output tensor index.
 * @param output_tensor Caller-provided buffer into which the output tensor
 *                      with index `index` is copied.
 * @param output_tensor_size In/out parameter: on entry, points to the
 *                           maximum capacity of `output_tensor` in bytes;
 *                           on return it is updated with the number of
 *                           bytes actually copied.
 * @return wasi_nn_error Execution status.
 */
wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_nn")));
#endif /* WASI_NN_H */