llm.c 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
/*
 * Copyright (c) 2006-2025, RT-Thread Development Team
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Change Logs:
 * Date           Author       Notes
 * 2025/02/01     Rbb666       Add license info
 * 2025/02/10     CXSforHPU    Add llm history support
 */
  11. #include "llm.h"
  12. #include "shell.h"
  13. static struct llm_obj handle = {0};
  14. static rt_uint8_t llm_getc(void)
  15. {
  16. rt_uint8_t ch = 0;
  17. while (rt_device_read(handle.device, (-1), &ch, 1) != 1)
  18. {
  19. rt_sem_take(&(handle.rx_sem), RT_WAITING_FOREVER);
  20. }
  21. return ch;
  22. }
  23. static rt_err_t llm_rxcb(rt_device_t dev, rt_size_t size)
  24. {
  25. return rt_sem_release(&(handle.rx_sem));
  26. }
  27. static rt_bool_t llm_handle_history(const char *prompt)
  28. {
  29. rt_kprintf("\033[2K\r");
  30. rt_kprintf("%s%s", prompt, handle.line);
  31. return RT_FALSE;
  32. }
  33. static void llm_push_history(void)
  34. {
  35. if (handle.line_position > 0)
  36. {
  37. if (handle.history_count >= LLM_HISTORY_LINES)
  38. {
  39. if (rt_memcmp(&handle.llm_history[LLM_HISTORY_LINES - 1], handle.line, PKG_LLM_CMD_BUFFER_SIZE))
  40. {
  41. int index;
  42. for (index = 0; index < FINSH_HISTORY_LINES - 1; index++)
  43. {
  44. rt_memcpy(&handle.llm_history[index][0], &handle.llm_history[index + 1][0], PKG_LLM_CMD_BUFFER_SIZE);
  45. }
  46. rt_memset(&handle.llm_history[index][0], 0, PKG_LLM_CMD_BUFFER_SIZE);
  47. rt_memcpy(&handle.llm_history[index][0], handle.line, handle.line_position);
  48. handle.history_count = LLM_HISTORY_LINES;
  49. }
  50. }
  51. else
  52. {
  53. if (handle.history_count == 0 || rt_memcmp(&handle.llm_history[handle.history_count - 1], handle.line, PKG_LLM_CMD_BUFFER_SIZE))
  54. {
  55. handle.history_current = handle.history_count;
  56. rt_memset(&handle.llm_history[handle.history_count][0], 0, PKG_LLM_CMD_BUFFER_SIZE);
  57. rt_memcpy(&handle.llm_history[handle.history_count][0], handle.line, handle.line_position);
  58. handle.history_count++;
  59. }
  60. }
  61. }
  62. handle.history_current = handle.history_count;
  63. }
  64. int llm_readline(const char *prompt, char *buffer, int buffer_size)
  65. {
  66. rt_uint8_t ch;
  67. start:
  68. rt_kprintf(prompt);
  69. while (1)
  70. {
  71. ch = llm_getc();
  72. if (ch == 0x1b)
  73. {
  74. handle.stat = LLM_WAIT_SPEC_KEY;
  75. continue;
  76. }
  77. else if (handle.stat == LLM_WAIT_SPEC_KEY)
  78. {
  79. if (ch == 0x5b)
  80. {
  81. handle.stat = LLM_WAIT_FUNC_KEY;
  82. continue;
  83. }
  84. handle.stat = LLM_WAIT_NORMAL;
  85. }
  86. else if (handle.stat == LLM_WAIT_FUNC_KEY)
  87. {
  88. handle.stat = LLM_WAIT_NORMAL;
  89. if (ch == 0x41)
  90. {
  91. if (handle.history_current > 0)
  92. {
  93. handle.history_current--;
  94. }
  95. else
  96. {
  97. handle.history_current = 0;
  98. continue;
  99. }
  100. rt_memcpy(handle.line, &handle.llm_history[handle.history_current][0], PKG_LLM_CMD_BUFFER_SIZE);
  101. handle.line_curpos = handle.line_position = rt_strlen(handle.line);
  102. llm_handle_history(prompt);
  103. continue;
  104. }
  105. else if (ch == 0x42)
  106. {
  107. if (handle.history_current < (handle.history_count - 1))
  108. {
  109. handle.history_current++;
  110. }
  111. else
  112. {
  113. if (handle.history_count != 0)
  114. {
  115. handle.history_current = handle.history_count - 1;
  116. }
  117. else
  118. {
  119. continue;
  120. }
  121. }
  122. rt_memcpy(handle.line, &handle.llm_history[handle.history_current][0], PKG_LLM_CMD_BUFFER_SIZE);
  123. handle.line_curpos = handle.line_position = rt_strlen(handle.line);
  124. llm_handle_history(prompt);
  125. continue;
  126. }
  127. else if (ch == 0x44)
  128. {
  129. if (handle.line_curpos)
  130. {
  131. rt_kprintf("\b");
  132. handle.line_curpos--;
  133. }
  134. continue;
  135. }
  136. else if (ch == 0x43)
  137. {
  138. if (handle.line_curpos < handle.line_position)
  139. {
  140. rt_kprintf("%c", handle.line[handle.line_curpos]);
  141. handle.line_curpos++;
  142. }
  143. continue;
  144. }
  145. }
  146. if (ch == '\0' || ch == 0xFF)
  147. {
  148. continue;
  149. }
  150. else if (ch == 0x7f || ch == 0x08)
  151. {
  152. if (handle.line_curpos == 0)
  153. {
  154. continue;
  155. }
  156. handle.line_curpos--;
  157. handle.line_position--;
  158. if (handle.line_position > handle.line_curpos)
  159. {
  160. rt_memmove(&handle.line[handle.line_curpos], &handle.line[handle.line_curpos + 1],
  161. handle.line_position - handle.line_curpos);
  162. handle.line[handle.line_position] = 0;
  163. rt_kprintf("\b%s \b", &handle.line[handle.line_curpos]);
  164. int index;
  165. for (index = (handle.line_curpos); index <= (handle.line_position); index++)
  166. {
  167. rt_kprintf("\b");
  168. }
  169. }
  170. else
  171. {
  172. rt_kprintf("\b \b");
  173. handle.line[handle.line_position] = 0;
  174. }
  175. continue;
  176. }
  177. else if (ch == '\r' || ch == '\n')
  178. {
  179. llm_push_history();
  180. rt_kprintf("\n");
  181. if (handle.line_position == 0)
  182. {
  183. goto start;
  184. }
  185. else
  186. {
  187. rt_uint8_t temp = handle.line_position;
  188. rt_strncpy(buffer, handle.line, handle.line_position);
  189. rt_memset(handle.line, 0x00, sizeof(handle.line));
  190. buffer[handle.line_position] = 0;
  191. handle.line_curpos = handle.line_position = 0;
  192. return temp;
  193. }
  194. }
  195. else if (ch == 0x04)
  196. {
  197. if (handle.line_position == 0)
  198. {
  199. return 0;
  200. }
  201. else
  202. {
  203. continue;
  204. }
  205. }
  206. else if (ch == '\t')
  207. {
  208. continue;
  209. }
  210. if (handle.line_position >= PKG_LLM_CMD_BUFFER_SIZE)
  211. {
  212. continue;
  213. }
  214. if (handle.line_curpos < handle.line_position)
  215. {
  216. rt_memmove(&handle.line[handle.line_curpos + 1], &handle.line[handle.line_curpos],
  217. handle.line_position - handle.line_curpos);
  218. handle.line[handle.line_curpos] = ch;
  219. rt_kprintf("%s", &handle.line[handle.line_curpos]);
  220. int index;
  221. for (index = (handle.line_curpos); index < (handle.line_position); index++)
  222. {
  223. rt_kprintf("\b");
  224. }
  225. }
  226. else
  227. {
  228. handle.line[handle.line_position] = ch;
  229. rt_kprintf("%c", ch);
  230. }
  231. ch = 0;
  232. handle.line_curpos++;
  233. handle.line_position++;
  234. }
  235. }
  236. static void llm_run(void *p)
  237. {
  238. char input_buffer[PKG_LLM_CMD_BUFFER_SIZE] = {0};
  239. const char *device_name = RT_CONSOLE_DEVICE_NAME;
  240. handle.device = rt_device_find(device_name);
  241. if (handle.device == RT_NULL)
  242. {
  243. LLM_DBG("The msh device find failed.\n");
  244. return;
  245. }
  246. handle.rx_indicate = handle.device->rx_indicate;
  247. rt_device_set_rx_indicate(handle.device, llm_rxcb);
  248. if (handle.argc == 1)
  249. {
  250. rt_kprintf("\nPress CTRL+D to exit llm shell.\n");
  251. }
  252. while (1)
  253. {
  254. int length = llm_readline("Enter command: ", input_buffer, PKG_LLM_CMD_BUFFER_SIZE);
  255. if (length == 0)
  256. {
  257. rt_kprintf("Exit terminal.\n");
  258. break;
  259. }
  260. else if (length > 0)
  261. {
  262. #ifdef PKG_LLMCHAT_HISTORY_PAYLOAD
  263. add_message2messages(input_buffer, "user", &handle);
  264. {
  265. char *result = handle.get_answer(&handle, handle.messages);
  266. add_message2messages(result, "assistant", &handle);
  267. }
  268. #else
  269. add_message2messages(input_buffer, "user", &handle);
  270. {
  271. char *result = handle.get_answer(&handle, handle.messages);
  272. rt_free(result);
  273. clear_messages(&handle);
  274. }
  275. #endif
  276. }
  277. else
  278. {
  279. rt_kprintf("No valid input.\n");
  280. }
  281. rt_memset(input_buffer, 0, sizeof(input_buffer));
  282. }
  283. rt_sem_detach(&(handle.rx_sem));
  284. rt_device_set_rx_indicate(handle.device, handle.rx_indicate);
  285. rt_kprintf(FINSH_PROMPT);
  286. }
  287. static int llm2rtt(int argc, char **argv)
  288. {
  289. static rt_bool_t history_init = RT_FALSE;
  290. if (history_init == RT_FALSE)
  291. {
  292. rt_memset(&handle, 0x00, sizeof(struct llm_obj));
  293. history_init = RT_TRUE;
  294. }
  295. else
  296. {
  297. handle.stat = LLM_WAIT_NORMAL;
  298. handle.argc = 0;
  299. rt_memset(handle.line, 0x00, PKG_LLM_CMD_BUFFER_SIZE);
  300. handle.line_position = 0;
  301. handle.line_curpos = 0;
  302. handle.device = RT_NULL;
  303. handle.rx_indicate = RT_NULL;
  304. handle.stream_cb = RT_NULL;
  305. handle.stream_user_data = RT_NULL;
  306. }
  307. rt_sem_init(&(handle.rx_sem), "llm_rxsem", 0, RT_IPC_FLAG_FIFO);
  308. handle.argc = argc;
  309. handle.get_answer = get_llm_answer;
  310. handle.stream_cb = RT_NULL;
  311. handle.stream_user_data = RT_NULL;
  312. if (!cJSON_IsArray(handle.messages))
  313. {
  314. handle.messages = cJSON_CreateArray();
  315. }
  316. #if defined(RT_VERSION_CHECK) && (RTTHREAD_VERSION >= RT_VERSION_CHECK(5, 1, 0))
  317. rt_uint8_t prio = RT_SCHED_PRIV(rt_thread_self()).current_priority + 1;
  318. #else
  319. rt_uint8_t prio = rt_thread_self()->current_priority + 1;
  320. #endif
  321. rt_err_t result = rt_thread_init(&handle.thread,
  322. "llm_td",
  323. llm_run, RT_NULL,
  324. &handle.thread_stack[0], sizeof(handle.thread_stack),
  325. prio, 10);
  326. if (result != RT_EOK)
  327. {
  328. rt_sem_detach(&(handle.rx_sem));
  329. LLM_DBG("The llm interpreter thread create failed.\n");
  330. return RT_ERROR;
  331. }
  332. rt_thread_startup(&handle.thread);
  333. return RT_EOK;
  334. }
  335. MSH_CMD_EXPORT_ALIAS(llm2rtt, llm, llm Interactive Terminal.);