net.h 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
  14. #ifndef NCNN_NET_H
  15. #define NCNN_NET_H
  16. #include "blob.h"
  17. #include "layer.h"
  18. #include "mat.h"
  19. #include "option.h"
  20. #include "platform.h"
  21. #if NCNN_PLATFORM_API
  22. #if __ANDROID_API__ >= 9
  23. #include <android/asset_manager.h>
  24. #endif // __ANDROID_API__ >= 9
  25. #endif // NCNN_PLATFORM_API
  26. namespace ncnn {
  27. #if NCNN_VULKAN
  28. class VkCompute;
  29. #endif // NCNN_VULKAN
  30. class DataReader;
  31. class Extractor;
  32. class NetPrivate;
  33. class NCNN_EXPORT Net
  34. {
  35. public:
  36. // empty init
  37. Net();
  38. // clear and destroy
  39. virtual ~Net();
  40. public:
  41. // option can be changed before loading
  42. Option opt;
  43. #if NCNN_VULKAN
  44. // set gpu device by index
  45. void set_vulkan_device(int device_index);
  46. // set gpu device by device handle, no owner transfer
  47. void set_vulkan_device(const VulkanDevice* vkdev);
  48. const VulkanDevice* vulkan_device() const;
  49. #endif // NCNN_VULKAN
  50. #if NCNN_STRING
  51. // register custom layer by layer type name
  52. // return 0 if success
  53. int register_custom_layer(const char* type, layer_creator_func creator, layer_destroyer_func destroyer = 0, void* userdata = 0);
  54. virtual int custom_layer_to_index(const char* type);
  55. #endif // NCNN_STRING
  56. // register custom layer by layer type
  57. // return 0 if success
  58. int register_custom_layer(int index, layer_creator_func creator, layer_destroyer_func destroyer = 0, void* userdata = 0);
  59. #if NCNN_STRING
  60. int load_param(const DataReader& dr);
  61. #endif // NCNN_STRING
  62. int load_param_bin(const DataReader& dr);
  63. int load_model(const DataReader& dr);
  64. #if NCNN_STDIO
  65. #if NCNN_STRING
  66. // load network structure from plain param file
  67. // return 0 if success
  68. int load_param(FILE* fp);
  69. int load_param(const char* protopath);
  70. int load_param_mem(const char* mem);
  71. #endif // NCNN_STRING
  72. // load network structure from binary param file
  73. // return 0 if success
  74. int load_param_bin(FILE* fp);
  75. int load_param_bin(const char* protopath);
  76. // load network weight data from model file
  77. // return 0 if success
  78. int load_model(FILE* fp);
  79. int load_model(const char* modelpath);
  80. #endif // NCNN_STDIO
  81. // load network structure from external memory
  82. // memory pointer must be 32-bit aligned
  83. // return bytes consumed
  84. int load_param(const unsigned char* mem);
  85. // reference network weight data from external memory
  86. // weight data is not copied but referenced
  87. // so external memory should be retained when used
  88. // memory pointer must be 32-bit aligned
  89. // return bytes consumed
  90. int load_model(const unsigned char* mem);
  91. #if NCNN_PLATFORM_API
  92. #if __ANDROID_API__ >= 9
  93. #if NCNN_STRING
  94. // convenient load network structure from android asset plain param file
  95. int load_param(AAsset* asset);
  96. int load_param(AAssetManager* mgr, const char* assetpath);
  97. #endif // NCNN_STRING
  98. // convenient load network structure from android asset binary param file
  99. int load_param_bin(AAsset* asset);
  100. int load_param_bin(AAssetManager* mgr, const char* assetpath);
  101. // convenient load network weight data from android asset model file
  102. int load_model(AAsset* asset);
  103. int load_model(AAssetManager* mgr, const char* assetpath);
  104. #endif // __ANDROID_API__ >= 9
  105. #endif // NCNN_PLATFORM_API
  106. // unload network structure and weight data
  107. void clear();
  108. // construct an Extractor from network
  109. Extractor create_extractor() const;
  110. // get input/output indexes/names
  111. const std::vector<int>& input_indexes() const;
  112. const std::vector<int>& output_indexes() const;
  113. #if NCNN_STRING
  114. const std::vector<const char*>& input_names() const;
  115. const std::vector<const char*>& output_names() const;
  116. #endif
  117. const std::vector<Blob>& blobs() const;
  118. const std::vector<Layer*>& layers() const;
  119. std::vector<Blob>& mutable_blobs();
  120. std::vector<Layer*>& mutable_layers();
  121. protected:
  122. friend class Extractor;
  123. #if NCNN_STRING
  124. int find_blob_index_by_name(const char* name) const;
  125. int find_layer_index_by_name(const char* name) const;
  126. virtual Layer* create_custom_layer(const char* type);
  127. #endif // NCNN_STRING
  128. virtual Layer* create_custom_layer(int index);
  129. private:
  130. Net(const Net&);
  131. Net& operator=(const Net&);
  132. private:
  133. NetPrivate* const d;
  134. };
  135. class ExtractorPrivate;
  136. class NCNN_EXPORT Extractor
  137. {
  138. public:
  139. virtual ~Extractor();
  140. // copy
  141. Extractor(const Extractor&);
  142. // assign
  143. Extractor& operator=(const Extractor&);
  144. // clear blob mats and alloctors
  145. void clear();
  146. // enable light mode
  147. // intermediate blob will be recycled when enabled
  148. // enabled by default
  149. void set_light_mode(bool enable);
  150. // set thread count for this extractor
  151. // this will overwrite the global setting
  152. // default count is system depended
  153. void set_num_threads(int num_threads);
  154. // set blob memory allocator
  155. void set_blob_allocator(Allocator* allocator);
  156. // set workspace memory allocator
  157. void set_workspace_allocator(Allocator* allocator);
  158. #if NCNN_VULKAN
  159. void set_vulkan_compute(bool enable);
  160. void set_blob_vkallocator(VkAllocator* allocator);
  161. void set_workspace_vkallocator(VkAllocator* allocator);
  162. void set_staging_vkallocator(VkAllocator* allocator);
  163. #endif // NCNN_VULKAN
  164. #if NCNN_STRING
  165. // set input by blob name
  166. // return 0 if success
  167. int input(const char* blob_name, const Mat& in);
  168. // get result by blob name
  169. // return 0 if success
  170. // type = 0, default
  171. // type = 1, do not convert fp16/bf16 or / and packing
  172. int extract(const char* blob_name, Mat& feat, int type = 0);
  173. #endif // NCNN_STRING
  174. // set input by blob index
  175. // return 0 if success
  176. int input(int blob_index, const Mat& in);
  177. // get result by blob index
  178. // return 0 if success
  179. // type = 0, default
  180. // type = 1, do not convert fp16/bf16 or / and packing
  181. int extract(int blob_index, Mat& feat, int type = 0);
  182. #if NCNN_VULKAN
  183. #if NCNN_STRING
  184. // set input by blob name
  185. // return 0 if success
  186. int input(const char* blob_name, const VkMat& in);
  187. // get result by blob name
  188. // return 0 if success
  189. int extract(const char* blob_name, VkMat& feat, VkCompute& cmd);
  190. // set input by blob name
  191. // return 0 if success
  192. int input(const char* blob_name, const VkImageMat& in);
  193. // get result by blob name
  194. // return 0 if success
  195. int extract(const char* blob_name, VkImageMat& feat, VkCompute& cmd);
  196. #endif // NCNN_STRING
  197. // set input by blob index
  198. // return 0 if success
  199. int input(int blob_index, const VkMat& in);
  200. // get result by blob index
  201. // return 0 if success
  202. int extract(int blob_index, VkMat& feat, VkCompute& cmd);
  203. // set input by blob index
  204. // return 0 if success
  205. int input(int blob_index, const VkImageMat& in);
  206. // get result by blob index
  207. // return 0 if success
  208. int extract(int blob_index, VkImageMat& feat, VkCompute& cmd);
  209. #endif // NCNN_VULKAN
  210. protected:
  211. friend Extractor Net::create_extractor() const;
  212. Extractor(const Net* net, size_t blob_count);
  213. private:
  214. ExtractorPrivate* const d;
  215. };
  216. } // namespace ncnn
  217. #endif // NCNN_NET_H