rt_ai_mpython.c 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. /*
  2. * Copyright (c) 2006-2021, RT-Thread Development Team
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Change Logs:
  7. * Date Author Notes
  8. * 2022-02-28 liqiwen the first version
  9. */
  10. #include <rtthread.h>
  11. #ifdef MICROPYTHON_USING_USEREXTMODS
  12. #include <rt_ai.h>
  13. #include "py/qstr.h"
  14. #include "py/obj.h"
  15. #include "py/runtime.h"
  16. #include "model_paser_helper.h"
  17. #define __debug(_val) mp_printf(&mp_plat_print,"%s,%d," #_val "=%d\n",__FILE__,__LINE__,_val);
  18. typedef struct
  19. {
  20. mp_obj_base_t base;
  21. rt_ai_t handle;
  22. void *buf;
  23. } py_model_obj_t;
  24. #define py_model_cobj(py_obj) ((py_obj)->handle)
  25. const mp_obj_type_t py_model_type;
  26. STATIC mp_obj_t py_rt_ai_init(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args){
  27. enum
  28. {
  29. ARG_model,
  30. ARG_size,
  31. };
  32. static const mp_arg_t allowed_args[] = {
  33. {MP_QSTR_model, MP_ARG_OBJ, {.u_obj = mp_const_none}},
  34. {MP_QSTR_size, MP_ARG_INT, {.u_int = 0}},
  35. };
  36. mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
  37. mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
  38. if(mp_obj_get_type(args[ARG_model].u_obj) != & py_model_type){
  39. mp_raise_TypeError("model type err!");
  40. }
  41. py_model_obj_t *model_obj = (py_model_obj_t *)args[ARG_model].u_obj;
  42. rt_ai_t handle = py_model_cobj(model_obj);
  43. uint32_t size = args[ARG_size].u_int;
  44. int stat = 0;
  45. if(!size){
  46. stat = rt_ai_init(handle, NULL);
  47. }
  48. else{
  49. model_obj->buf = m_malloc(size);
  50. stat = rt_ai_init(handle, model_obj->buf);
  51. }
  52. if(stat){
  53. mp_raise_OSError(stat);
  54. }
  55. if(!handle->info.input_n && !handle->info.input_n){
  56. handle->info.input_n = inputs_size(handle); __debug(handle->info.input_n);
  57. handle->info.output_n = outputs_size(handle); __debug(handle->info.output_n);
  58. for(int i=0; i<handle->info.input_n; i++){
  59. handle->info.input_n_stack[i] = inputs_n_bytes(handle, i); __debug(handle->info.input_n_stack[i]);
  60. }
  61. for(int i=0; i<handle->info.output_n; i++){
  62. handle->info.output_n_stack[i] = outputs_n_bytes(handle, i);__debug(handle->info.output_n_stack[i]);
  63. }
  64. }
  65. return mp_const_none;
  66. }
  67. STATIC MP_DEFINE_CONST_FUN_OBJ_KW(py_rt_ai_init_obj, 1, py_rt_ai_init);
  68. static int ai_done(void *ctx)
  69. {
  70. *((uint32_t*)ctx)= 1;
  71. return 0;
  72. }
  73. STATIC mp_obj_t py_rt_ai_run(mp_obj_t _obj, mp_obj_t data){
  74. if(mp_obj_get_type(_obj) != & py_model_type){
  75. mp_raise_TypeError("model type err!");
  76. }
  77. py_model_obj_t *model_obj = (py_model_obj_t*)_obj;
  78. volatile uint32_t g_ai_done_flag = 0;
  79. mp_buffer_info_t bufinfo;
  80. if (mp_get_buffer(data, &bufinfo, MP_BUFFER_READ)){
  81. py_model_cobj(model_obj)->input[0] = bufinfo.buf;
  82. rt_ai_run(py_model_cobj(model_obj), ai_done, &g_ai_done_flag);
  83. while(!g_ai_done_flag);
  84. }
  85. else{
  86. mp_raise_ValueError("now only support image_t");
  87. }
  88. return mp_const_none;
  89. }
  90. static MP_DEFINE_CONST_FUN_OBJ_2(py_rt_ai_run_obj, py_rt_ai_run);
  91. STATIC mp_obj_t py_rt_ai_output(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args)
  92. {
  93. enum
  94. {
  95. ARG_model,
  96. ARG_index,
  97. ARG_getlist,
  98. };
  99. static const mp_arg_t allowed_args[] = {
  100. {MP_QSTR_model, MP_ARG_OBJ, {.u_obj = mp_const_none}},
  101. {MP_QSTR_index, MP_ARG_INT, {.u_int = 0}},
  102. {MP_QSTR_getlist, MP_ARG_BOOL, {.u_bool = 1}},
  103. };
  104. mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
  105. mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
  106. if(mp_obj_get_type(args[ARG_model].u_obj) != & py_model_type){
  107. mp_raise_TypeError("model type err!");
  108. }
  109. py_model_obj_t *model_obj = (py_model_obj_t *)args[ARG_model].u_obj;
  110. rt_ai_t handle = py_model_cobj(model_obj);
  111. mp_obj_list_t *ret_list = NULL;
  112. mp_int_t index = args[ARG_index].u_int;
  113. if((index < 0) || index > handle->info.output_n)
  114. {
  115. mp_raise_ValueError("excess of output");
  116. }
  117. float *out;
  118. out = rt_ai_output(handle, index);
  119. if(args[ARG_getlist].u_bool){
  120. ret_list = m_new(mp_obj_list_t, 1);
  121. mp_obj_list_init(ret_list, 0);
  122. for(int j = 0; j < (handle->info.output_n_stack[index])/sizeof(float); j++){
  123. mp_obj_list_append(ret_list, mp_obj_new_float(out[j]) );
  124. }
  125. return MP_OBJ_FROM_PTR(ret_list);
  126. }
  127. return mp_const_none;
  128. }
  129. STATIC MP_DEFINE_CONST_FUN_OBJ_KW(py_rt_ai_output_obj, 1, py_rt_ai_output);
  130. STATIC mp_obj_t make_new()
  131. {
  132. py_model_obj_t *self = m_new_obj_with_finaliser(py_model_obj_t);
  133. self->base.type = &py_model_type;
  134. self->handle = RT_NULL;
  135. self->buf = NULL;
  136. return self;
  137. }
  138. STATIC const mp_rom_map_elem_t locals_dict_table[] =
  139. {
  140. {MP_ROM_QSTR(MP_QSTR_name), MP_OBJ_NEW_QSTR(MP_QSTR_name)},
  141. };
  142. STATIC MP_DEFINE_CONST_DICT(locals_dict, locals_dict_table);
  143. const mp_obj_type_t py_model_type =
  144. {
  145. {&mp_type_type},
  146. .name = MP_QSTR_model,
  147. // .make_new = make_new,
  148. .locals_dict = (mp_obj_t)&locals_dict
  149. };
  150. STATIC mp_obj_t py_rt_ai_find(mp_obj_t name){
  151. volatile uint32_t g_ai_done_flag = 0;
  152. if(mp_obj_get_type(name) == &mp_type_str){
  153. const char *model_name = mp_obj_str_get_str(name);
  154. rt_ai_t handle = rt_ai_find(model_name);
  155. if(!handle){
  156. mp_raise_ValueError("error! not find model");
  157. }
  158. py_model_obj_t *model_obj = m_new_obj(py_model_obj_t);
  159. model_obj->handle = handle;
  160. model_obj->base.type = &py_model_type;
  161. model_obj->buf = NULL;
  162. return MP_OBJ_FROM_PTR(model_obj);
  163. }
  164. else{
  165. mp_raise_TypeError("please type model name str!");
  166. }
  167. return mp_const_none;
  168. }
  169. static MP_DEFINE_CONST_FUN_OBJ_1(py_rt_ai_find_obj, py_rt_ai_find);
  170. STATIC mp_obj_t py_rt_ai_load(mp_obj_t buffer, mp_obj_t name){
  171. if(mp_obj_get_type(name) != &mp_type_str){
  172. mp_raise_TypeError("please type model name str!");
  173. }
  174. mp_buffer_info_t bufinfo;
  175. if (!mp_get_buffer(buffer, &bufinfo, MP_BUFFER_READ)){
  176. mp_raise_ValueError("get kmodel buffer error!");
  177. }
  178. const char *model_name = mp_obj_str_get_str(name);
  179. rt_ai_t handle = backend_k210_kpu_constructor_helper(bufinfo.buf, model_name);
  180. py_model_obj_t *model_obj = m_new_obj(py_model_obj_t);
  181. model_obj->base.type = &py_model_type;
  182. model_obj->handle = handle;
  183. model_obj->buf = RT_NULL;
  184. return MP_OBJ_FROM_PTR(model_obj);
  185. }
  186. static MP_DEFINE_CONST_FUN_OBJ_2(py_rt_ai_load_obj, py_rt_ai_load);
  187. STATIC mp_obj_t py_rt_ai_free(mp_obj_t _obj){
  188. if(mp_obj_get_type(_obj) != & py_model_type){
  189. mp_raise_TypeError("model type err!");
  190. }
  191. py_model_obj_t *model_obj = (py_model_obj_t*)_obj;
  192. backend_k210_kpu_kmodel_free(model_obj->handle);
  193. return mp_const_none;
  194. }
  195. static MP_DEFINE_CONST_FUN_OBJ_1(py_rt_ai_free_obj, py_rt_ai_free);
  196. STATIC const mp_rom_map_elem_t rt_ak_module_globals_table[] = {
  197. { MP_OBJ_NEW_QSTR(MP_QSTR___name__), MP_OBJ_NEW_QSTR(MP_QSTR_RT_AK) },
  198. { MP_OBJ_NEW_QSTR(MP_QSTR_model), MP_ROM_PTR(&py_model_type)},
  199. { MP_OBJ_NEW_QSTR(MP_QSTR_ai_load), MP_ROM_PTR(&py_rt_ai_load_obj) },
  200. { MP_OBJ_NEW_QSTR(MP_QSTR_ai_find), MP_ROM_PTR(&py_rt_ai_find_obj) },
  201. { MP_OBJ_NEW_QSTR(MP_QSTR_ai_init), MP_ROM_PTR(&py_rt_ai_init_obj) },
  202. { MP_OBJ_NEW_QSTR(MP_QSTR_ai_run), MP_ROM_PTR(&py_rt_ai_run_obj) },
  203. { MP_OBJ_NEW_QSTR(MP_QSTR_ai_output), MP_ROM_PTR(&py_rt_ai_output_obj) },
  204. { MP_OBJ_NEW_QSTR(MP_QSTR_ai_free), MP_ROM_PTR(&py_rt_ai_free_obj) },
  205. };
  206. STATIC MP_DEFINE_CONST_DICT(rt_ak_module_globals, rt_ak_module_globals_table);
  207. const mp_obj_module_t rt_ak_module = {
  208. .base = { &mp_type_module },
  209. .globals = (mp_obj_dict_t*)&rt_ak_module_globals,
  210. };
  211. #endif //PKG_USING_MICROPYTHON