# Client.py — MCP client for the LLM-driven Python -> C (PikaPython) transpiler agent.
# Standard library
import argparse
import asyncio
import json
import math
import os
import traceback
from datetime import datetime
from pathlib import Path
from typing import TypedDict, Annotated, List, Dict, Any, Optional

# Third-party
import pyfiglet
from langchain_core.messages import SystemMessage, BaseMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.tools import BaseTool
from langchain_mcp_adapters.tools import load_mcp_tools
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END, add_messages
from langgraph.types import interrupt, Command
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# Local
from MCP_config import MODEL_NAME, BASE_URL, API_KEY, MAX_RETRYIES, TIME_OUT, debug_level, TEMPERATURE, MAX_TOKENS, RECURSION_LIMIT, LOG_ENABLED
from prompt.prompt_loader import load_prompt
  21. #定义状态类
  22. class State(TypedDict):
  23. user_input: str
  24. user_id: str
  25. messages: Annotated[List[dict], add_messages] # 消息信息
  26. # 自定义条件判断函数
  27. def should_continue(state: State) -> str:
  28. messages = state['messages']
  29. last_message = messages[-1]
  30. # 如果最后一条消息是AIMessage且有工具调用,则继续调用工具
  31. if isinstance(last_message, AIMessage) and last_message.tool_calls:
  32. return "tools"
  33. # 否则结束
  34. return "end"
  35. class SingletonMeta(type):
  36. """单例元类"""
  37. _instances = {}
  38. def __call__(cls, *args, **kwargs):
  39. if cls not in cls._instances:
  40. cls._instances[cls] = super().__call__(*args, **kwargs)
  41. return cls._instances[cls]
  42. class PythonCTranspiler(metaclass=SingletonMeta):
  43. """Python到C转换器的代理类"""
  44. def __init__(self):
  45. # 配置MCP服务器参数
  46. self.server_params = StdioServerParameters(
  47. command="python",
  48. args=["MCP_server.py"]
  49. )
  50. self.llm_with_tools = None
  51. self.mcp_tools = None
  52. self.graph = None
  53. self.session = None
  54. self.initialized = False # 明确设置初始化状态
  55. # 日志相关
  56. self.session_log_dir: Path | None = None
  57. self.llm_call_index: int = 0
  58. self.tool_call_index: int = 0
  59. self.root_log_dir: Path = Path("./logs")
  60. self._logging_enabled: bool = LOG_ENABLED
  61. # 动态工作目录 (用于隔离单次任务生成文件,避免冲突)
  62. ts = datetime.now().strftime('%Y%m%d_%H%M%S')
  63. self.session_work_dir: Path = Path("./file_create") / ts
  64. try:
  65. self.session_work_dir.mkdir(parents=True, exist_ok=True)
  66. except Exception:
  67. pass
  68. # 传递给 MCP 服务端用于写入限制
  69. os.environ['SESSION_WORK_DIR'] = str(self.session_work_dir.resolve())
  70. # 原始系统 prompt 文件名(在 main 中设置)
  71. self._base_prompt_filename: str | None = None
  72. # Token 用量累计
  73. self._usage_totals = {
  74. 'completion_tokens': 0,
  75. 'prompt_tokens': 0,
  76. 'prompt_cache_hit_tokens': 0,
  77. 'prompt_cache_miss_tokens': 0,
  78. 'total_tokens': 0,
  79. 'completion_tokens_details': {
  80. 'reasoning_tokens': 0
  81. },
  82. 'cost_yuan': 0.0
  83. }
  84. # ================= 日志辅助函数 =================
  85. def _init_session_logging(self):
  86. """初始化本次用户请求的日志目录。
  87. 只在新的用户顶层输入时调用一次。"""
  88. if not self._logging_enabled:
  89. self.session_log_dir = None
  90. return
  91. now = datetime.now()
  92. session_dir_name = f"session_{now.strftime('%Y%m%d_%H%M%S')}"
  93. self.session_log_dir = self.root_log_dir / session_dir_name
  94. try:
  95. self.session_log_dir.mkdir(parents=True, exist_ok=True)
  96. except Exception:
  97. # 目录创建失败则置空,后续写日志自动跳过
  98. self.session_log_dir = None
  99. print(f"日志目录:{self.session_log_dir}")
  100. self.llm_call_index = 0
  101. self.tool_call_index = 0
  102. def _write_log(self, filename: str, data: dict | str):
  103. """写入单个日志文件。失败时静默。"""
  104. if not self._logging_enabled or self.session_log_dir is None:
  105. return
  106. try:
  107. file_path = self.session_log_dir / filename
  108. if isinstance(data, dict):
  109. content = json.dumps(data, ensure_ascii=False, indent=2)
  110. else:
  111. content = str(data)
  112. file_path.write_text(content, encoding='utf-8')
  113. except Exception:
  114. pass
  115. def _write_log_end(self, filename: str, data: dict | str):
  116. """写入结束日志文件。"""
  117. if not self._logging_enabled or self.session_log_dir is None:
  118. return
  119. try:
  120. file_path = self.session_log_dir / filename
  121. content = str(data)
  122. file_path.write_text(content, encoding='utf-8')
  123. except Exception:
  124. pass
  125. def set_logging(self, enabled: bool):
  126. """运行时切换日志开关。若关闭则之后不再写入新的日志。"""
  127. self._logging_enabled = bool(enabled)
  128. async def initialize(self):
  129. """初始化MCP会话和工具"""
  130. # 创建stdio客户端和会话
  131. # 手动创建stdio客户端
  132. if self.initialized:
  133. return self # 如果已经初始化,直接返回
  134. self.stdio_client = stdio_client(self.server_params)
  135. self.read_stream, self.write_stream = await self.stdio_client.__aenter__()
  136. print("mcp客户端创建成功\n")
  137. # 创建会话
  138. self.session = ClientSession(self.read_stream, self.write_stream)
  139. await self.session.__aenter__()
  140. print("mcp会话创建成功\n")
  141. # 初始化连接
  142. await self.session.initialize()
  143. print("连接初始化成功\n")
  144. # 从MCP中获取工具
  145. self.mcp_tools = await load_mcp_tools(self.session)
  146. # 初始化LLM
  147. # 这里把 langchain 内置重试关闭 (max_retries=0), 由我们自定义的指数/线性混合回退控制
  148. # 避免重复重试导致延迟不可控
  149. llm = ChatOpenAI(
  150. model=MODEL_NAME,
  151. base_url=BASE_URL,
  152. api_key=API_KEY,
  153. temperature=TEMPERATURE,
  154. max_tokens=MAX_TOKENS,
  155. max_retries=0,
  156. timeout=TIME_OUT
  157. )
  158. # 将工具绑定到LLM
  159. self.llm_with_tools = llm.bind_tools(self.mcp_tools)
  160. # 创建代理图
  161. self._create_agent_graph()
  162. print("图创建成功")
  163. # 新的用户顶层输入:初始化 session 日志目录
  164. if not self.root_log_dir.exists():
  165. try:
  166. self.root_log_dir.mkdir(parents=True, exist_ok=True)
  167. except Exception:
  168. pass
  169. self._init_session_logging()
  170. self.initialized = True # 初始化完成后设置标志
  171. return self
  172. async def close(self):
  173. """关闭会话和连接"""
  174. # 按正确顺序关闭:先 session.__aexit__ 再 stdio_client.__aexit__
  175. try:
  176. if self.session:
  177. try:
  178. await self.session.__aexit__(None, None, None)
  179. except Exception as e:
  180. print(f"[WARN] 关闭 session 异常: {e}")
  181. finally:
  182. if hasattr(self, 'stdio_client') and self.stdio_client:
  183. try:
  184. await self.stdio_client.__aexit__(None, None, None)
  185. except Exception as e:
  186. print(f"[WARN] 关闭 stdio_client 异常: {e}")
  187. async def __aenter__(self):
  188. """异步上下文管理器入口"""
  189. return await self.initialize()
  190. async def __aexit__(self, exc_type, exc_val, exc_tb):
  191. """异步上下文管理器退出"""
  192. await self.close()
  193. def _create_agent_graph(self):
  194. """创建代理图"""
  195. graph_builder = StateGraph(State)
  196. # 添加节点
  197. graph_builder.add_node("agent", self._call_model)
  198. graph_builder.add_node("tools", self._call_tool)
  199. # 设置入口点
  200. graph_builder.set_entry_point("agent")
  201. # 添加条件边
  202. graph_builder.add_conditional_edges(
  203. "agent",
  204. should_continue,
  205. {
  206. "tools": "tools", # 需要工具调用
  207. "end": END # 不需要工具调用,结束
  208. }
  209. )
  210. # 从工具节点返回代理节点
  211. graph_builder.add_edge("tools", "agent")
  212. # 编译图
  213. self.graph = graph_builder.compile()
  214. async def _call_model(self, state: State):
  215. """调用模型处理状态,增加自定义重试机制。
  216. 重试策略:
  217. 1. 默认延迟序列: 1,2,5,10,30,60,120 (秒)。
  218. 2. 可通过环境变量 LLM_RETRY_DELAYS 覆盖,格式: "1,2,5" (秒,正整数/浮点)。解析失败则回退默认。
  219. 3. 仅针对网络/额度/限流/临时性错误进行重试 (OpenAI 402余额不足仍按策略尝试直到序列结束, 便于在外部补款后继续)。
  220. 4. 每次失败记录日志: retry_index, delay, error_type, error_message。
  221. 5. 最终失败写入 final_error.log 并抛出异常 (让上游 graph 终止)。
  222. """
  223. messages = state['messages']
  224. # 解析自定义延迟序列
  225. default_delays = [1, 2, 5, 10, 30, 60, 120]
  226. env_delays_raw = os.getenv('LLM_RETRY_DELAYS')
  227. if env_delays_raw:
  228. try:
  229. parsed = []
  230. for part in env_delays_raw.split(','):
  231. p = part.strip()
  232. if not p:
  233. continue
  234. val = float(p)
  235. if val <= 0:
  236. continue
  237. parsed.append(val)
  238. if parsed:
  239. default_delays = parsed
  240. except Exception:
  241. # 解析失败静默回退
  242. pass
  243. attempts = len(default_delays) + 1 # 初始立即调用 + 延迟列表
  244. last_err: Exception | None = None
  245. for attempt in range(1, attempts + 1):
  246. try:
  247. response = await self.llm_with_tools.ainvoke(messages)
  248. break # 成功
  249. except Exception as e:
  250. last_err = e
  251. # 记录失败日志
  252. log_payload = {
  253. "phase": "LLM_CALL_RETRY_ERROR",
  254. "attempt": attempt,
  255. "max_attempts": attempts,
  256. "error_type": e.__class__.__name__,
  257. "error_message": str(e),
  258. }
  259. self._write_log('llm_retry.log', log_payload)
  260. # 终端提示: 当前失败 + 下次等待 (如果还会重试)
  261. if attempt < attempts:
  262. delay = default_delays[attempt - 1] if attempt - 1 < len(default_delays) else default_delays[-1]
  263. print(f"[LLM][Retry {attempt}/{attempts}] {e.__class__.__name__}: {str(e)[:140]} -- next wait {delay}s", flush=True)
  264. else:
  265. print(f"[LLM][Retry {attempt}/{attempts}] {e.__class__.__name__}: {str(e)[:140]} -- no more retries", flush=True)
  266. if attempt == attempts:
  267. # 达到最大次数,写终止日志
  268. final_payload = {
  269. "phase": "LLM_CALL_FINAL_FAILURE",
  270. "attempts": attempts,
  271. "error_type": e.__class__.__name__,
  272. "error_message": str(e),
  273. }
  274. self._write_log('llm_final_error.log', final_payload)
  275. print("[LLM][Abort] 已耗尽所有重试,任务终止。", flush=True)
  276. raise # 抛出最后的异常
  277. # 计算下一次的 delay
  278. delay = default_delays[attempt - 1] if attempt - 1 < len(default_delays) else default_delays[-1]
  279. try:
  280. await asyncio.sleep(delay)
  281. except asyncio.CancelledError:
  282. raise
  283. continue
  284. # 若循环未 break (理论不会), 直接抛出
  285. if last_err and 'response' not in locals():
  286. raise last_err
  287. # ========== 提取 usage 信息 (兼容不同字段结构) ==========
  288. usage_raw: dict | None = None
  289. try:
  290. # LangChain 常见: response.response_metadata.token_usage
  291. resp_meta = getattr(response, 'response_metadata', None)
  292. if isinstance(resp_meta, dict):
  293. if 'token_usage' in resp_meta and isinstance(resp_meta['token_usage'], dict):
  294. usage_raw = resp_meta['token_usage']
  295. elif 'usage' in resp_meta and isinstance(resp_meta['usage'], dict):
  296. usage_raw = resp_meta['usage']
  297. if usage_raw is None:
  298. add_kwargs = getattr(response, 'additional_kwargs', None)
  299. if isinstance(add_kwargs, dict):
  300. # OpenAI 兼容格式: usage:{prompt_tokens,...}
  301. if 'usage' in add_kwargs and isinstance(add_kwargs['usage'], dict):
  302. usage_raw = add_kwargs['usage']
  303. except Exception:
  304. usage_raw = None
  305. usage_record = {}
  306. if usage_raw:
  307. # 标准化键名并提取
  308. ct = usage_raw.get('completion_tokens', 0)
  309. pt = usage_raw.get('prompt_tokens', 0)
  310. pch = usage_raw.get('prompt_cache_hit_tokens', 0)
  311. pcm = usage_raw.get('prompt_cache_miss_tokens', pt - pch) # 如果未提供,则计算
  312. tt = usage_raw.get('total_tokens', pt + ct)
  313. reasoning_tokens = 0
  314. if 'completion_tokens_details' in usage_raw and isinstance(usage_raw['completion_tokens_details'], dict):
  315. reasoning_tokens = usage_raw['completion_tokens_details'].get('reasoning_tokens', 0)
  316. # 计费规则 (元/百万 tokens)
  317. PRICE_CACHE_HIT = 0.5
  318. PRICE_CACHE_MISS = 4.0
  319. PRICE_OUTPUT = 12.0
  320. # 计算本次请求的费用
  321. cost = (
  322. (pch * PRICE_CACHE_HIT) +
  323. (pcm * PRICE_CACHE_MISS) +
  324. (ct * PRICE_OUTPUT)
  325. ) / 1_000_000
  326. # 累计
  327. self._usage_totals['completion_tokens'] += ct
  328. self._usage_totals['prompt_tokens'] += pt
  329. self._usage_totals['prompt_cache_hit_tokens'] += pch
  330. self._usage_totals['prompt_cache_miss_tokens'] += pcm
  331. self._usage_totals['total_tokens'] += tt
  332. self._usage_totals['completion_tokens_details']['reasoning_tokens'] += reasoning_tokens
  333. self._usage_totals['cost_yuan'] += cost
  334. usage_record = {
  335. 'completion_tokens': ct,
  336. 'prompt_tokens': pt,
  337. 'prompt_cache_hit_tokens': pch,
  338. 'prompt_cache_miss_tokens': pcm,
  339. 'total_tokens': tt,
  340. 'completion_tokens_details': {
  341. 'reasoning_tokens': reasoning_tokens
  342. },
  343. 'cost_yuan': cost
  344. }
  345. # 每次更新后,写入 summary_stats.log
  346. self._write_log('summary_stats.log', self._usage_totals)
  347. # 日志记录
  348. self.llm_call_index += 1
  349. now = datetime.now().strftime('%H%M%S')
  350. log_filename = f"{now}_LLM{self.llm_call_index}.log"
  351. # 整理 messages 为可序列化
  352. def serialize_msg(m: BaseMessage):
  353. base = {
  354. "type": m.__class__.__name__,
  355. "content": getattr(m, 'content', None)
  356. }
  357. if isinstance(m, AIMessage):
  358. base["tool_calls"] = getattr(m, 'tool_calls', None)
  359. if isinstance(m, ToolMessage):
  360. base["name"] = getattr(m, 'name', None)
  361. return base
  362. serialized_messages = [serialize_msg(m) for m in messages]
  363. # 最新输入消息(倒数第二个可能是 Human / Tool 等,最后一个是 response 前的 AI 触发?此处取倒数第二个作为当前触发上下文,若存在)
  364. current_trigger_message = serialize_msg(messages[-1]) if messages else None
  365. log_payload = {
  366. "phase": "LLM_CALL",
  367. "model": MODEL_NAME,
  368. "parameters": {
  369. "base_url": BASE_URL,
  370. "temperature": TEMPERATURE,
  371. "max_tokens": MAX_TOKENS,
  372. "timeout": TIME_OUT,
  373. },
  374. "current_message": current_trigger_message,
  375. "response": serialize_msg(response),
  376. "history_messages": serialized_messages,
  377. "usage": usage_record,
  378. "usage_totals_accumulated": self._usage_totals
  379. }
  380. self._write_log(log_filename, log_payload)
  381. # 返回更新后的消息列表
  382. return {"messages": [response]}
  383. async def _call_tool(self, state: State):
  384. """调用工具处理状态"""
  385. messages = state['messages']
  386. last_message = messages[-1]
  387. # 确保最后一条消息是AIMessage且有工具调用
  388. if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
  389. return {"messages": []}
  390. tool_calls = last_message.tool_calls
  391. results = []
  392. for tool_call in tool_calls:
  393. tool_name = tool_call['name']
  394. tool_args = tool_call['args']
  395. # 查找对应的工具实例
  396. tool_map = {t.name: t for t in self.mcp_tools}
  397. if tool_name not in tool_map:
  398. result = f"Error: 工具 {tool_name} 未找到"
  399. else:
  400. # 执行工具
  401. print(f"使用工具:{tool_name}")
  402. tool = tool_map[tool_name]
  403. try:
  404. # 直接传递原始参数,不做兼容/修改
  405. result = await tool.ainvoke(tool_args)
  406. print(f" 调用工具成功: {tool_name} 输出: {result}")
  407. except Exception as e:
  408. # 捕获所有异常,记录完整 traceback 到日志文件,并返回日志路径供上层 LLM 重试
  409. tb = traceback.format_exc()
  410. self.tool_call_index += 1
  411. now = datetime.now().strftime('%H%M%S')
  412. tool_log_filename = f"{now}_TOOL{self.tool_call_index}.log"
  413. log_payload = {
  414. "phase": "TOOL_CALL_EXCEPTION",
  415. "tool_name": tool_name,
  416. "args": tool_args,
  417. "error_type": e.__class__.__name__,
  418. "error_message": str(e),
  419. "traceback": tb
  420. }
  421. self._write_log(tool_log_filename, log_payload)
  422. # 将日志文件路径作为结果返回(相对路径)
  423. result = f"ERROR_LOG_PATH: {(self.session_log_dir / tool_log_filename).as_posix()}"
  424. # 记录工具日志
  425. self.tool_call_index += 1
  426. now = datetime.now().strftime('%H%M%S')
  427. tool_log_filename = f"{now}_TOOL{self.tool_call_index}.log"
  428. tool_log_payload = {
  429. "phase": "TOOL_CALL",
  430. "tool_name": tool_name,
  431. "args": tool_args,
  432. "result": str(result)
  433. }
  434. self._write_log(tool_log_filename, tool_log_payload)
  435. # 为每个工具调用生成一个ToolMessage
  436. results.append(
  437. ToolMessage(
  438. content=str(result),
  439. name=tool_name,
  440. tool_call_id=tool_call['id']
  441. )
  442. )
  443. # 返回更新后的消息列表 (不再包含自动成功总结逻辑)
  444. return {"messages": results}
  445. async def process_input(self, user_input: str, state: Optional[State] = None, system_prompt: str | None = None):
  446. """处理用户输入并返回更新后的状态"""
  447. if state is None:
  448. # 如果未显式传入,则动态加载并注入工作目录
  449. base_name = self._base_prompt_filename or 'core_task.md'
  450. sys_prompt_text = system_prompt if system_prompt is not None else self.prepare_system_prompt(base_name)
  451. state = State(messages=[SystemMessage(content=sys_prompt_text)])
  452. else:
  453. # 如果调用方已经提前构建了 state,但还未初始化日志,则此处补做
  454. if self.session_log_dir is None:
  455. if not self.root_log_dir.exists():
  456. try:
  457. self.root_log_dir.mkdir(parents=True, exist_ok=True)
  458. except Exception:
  459. pass
  460. self._init_session_logging()
  461. # 添加用户消息到状态
  462. state["messages"].append(HumanMessage(content=user_input))
  463. # 执行图
  464. invoke_config = {}
  465. if RECURSION_LIMIT is not None:
  466. invoke_config["recursion_limit"] = RECURSION_LIMIT
  467. result = await self.graph.ainvoke(state, config=invoke_config)
  468. # Client运行完后创建结束日志complate_message.log
  469. output_dict = {
  470. "user_input": result.get("user_input", ""),
  471. "user_id": result.get("user_id", ""),
  472. "messages": result.get("messages", "")
  473. }
  474. self._write_log_end("complate_message.log", output_dict)
  475. return result
  476. # ================= Prompt 动态注入 =================
  477. def prepare_system_prompt(self, prompt_file: str) -> str:
  478. """加载系统 prompt 并动态注入本次会话的工作目录说明。
  479. 规则:
  480. 1. 将所有出现的 './file_create/' 替换为 当前工作目录 (末尾带 '/')。
  481. 2. 支持占位符 '{{WORK_DIR}}' 被替换为当前工作目录。
  482. 3. 追加一段约束说明,强制 LLM 仅在该目录下写文件:<work_dir><module_name>/...
  483. """
  484. raw = load_prompt(prompt_file)
  485. work_dir_str = self.session_work_dir.as_posix() + "/"
  486. replaced = raw.replace('./file_create/', work_dir_str)
  487. replaced = replaced.replace('{{WORK_DIR}}', work_dir_str)
  488. # 规范 run_pika 命令里 --module-dir 参数为动态目录
  489. # 常见原始示例中使用: --module-dir ./file_create 或省略时间戳,因此替换这类片段
  490. replaced = replaced.replace('--module-dir ./file_create', f'--module-dir {work_dir_str.rstrip("/")}')
  491. # 如果 prompt 指出命令示例中直接写 test_example.py 路径,也替换为动态的
  492. replaced = replaced.replace(' ./file_create/test_example.py', f' {work_dir_str}test_example.py')
  493. appendix = (
  494. f"\n\n### 动态工作目录 (自动插入)\n"
  495. f"本次任务独立工作根目录: {work_dir_str}\n"
  496. "所有生成/修改文件必须位于该目录 (及其子目录)。禁止写入根仓库其它路径; 若需要读取日志/源码可只读不写。\n"
  497. "写入非该目录会被工具层直接拒绝。\n"
  498. "模块目录结构示例: <WORK_DIR><module_name>/<module_name>.pyi 与 C 实现文件。\n"
  499. f"运行构建示例命令: python run_pika.py --module <module_name> --module-dir {work_dir_str.rstrip('/')} {work_dir_str}test_example.py\n"\
  500. "\n### 环境限制\n"
  501. "1. 禁止使用 f-string 语法 (形如 f\"...{x}\").\n"
  502. "2. 禁止使用 round() 函数。\n"
  503. "违反上述任一会导致额外修补循环,必须一次性规避。\n"
  504. "\n### 成功判定与终止策略\n"
  505. "当首次在运行/自测输出中同时出现 'SELFTEST' 与 'OK' (或生成 [MODULE] 模块汇总块) 视为整体成功。随后立即: \n"
  506. "1. 输出 [SUMMARY] 段落(列出模块名/文件列表/步骤统计)。\n"
  507. "2. 不再提出新的工具调用或修改请求,直接结束。\n"
  508. "禁止在成功后继续追加改进操作; 改进建议只在 [SUMMARY] 里简述一行。\n"
  509. )
  510. return replaced + appendix
  511. # 主函数
  512. async def main():
  513. parser = argparse.ArgumentParser(description="Python->PikaPython 模块转换 Agent")
  514. parser.add_argument('--code', help='直接传入一段待转换的 Python 代码 (非交互模式)')
  515. parser.add_argument('--code-file', help='从文件读取待转换 Python 代码 (与 --code 互斥)')
  516. parser.add_argument('--prompt-file', default='core_task.md', help='指定使用的系统 prompt 文件名 (位于 prompt/ 下)')
  517. args = parser.parse_args()
  518. pyfig = pyfiglet.figlet_format("Python->C")
  519. print(pyfig)
  520. print("="*30)
  521. print("\033[1;33;40m llm驱动的python-C跨语言编译系统 (Pika集成路径)\033[0m")
  522. print("="*30)
  523. print("\n")
  524. # 装载系统 prompt
  525. try:
  526. system_prompt_text = load_prompt(args.prompt_file)
  527. except FileNotFoundError as e:
  528. print(f"[FATAL] Prompt 文件不存在: {e}")
  529. return
  530. transpiler = PythonCTranspiler()
  531. transpiler._base_prompt_filename = args.prompt_file
  532. await transpiler.initialize()
  533. # 初始 state
  534. # 使用动态注入后的 prompt
  535. dynamic_prompt = transpiler.prepare_system_prompt(args.prompt_file)
  536. state = State(messages=[SystemMessage(content=dynamic_prompt)])
  537. # 非交互一次性模式
  538. if args.code or args.code_file:
  539. start_time = datetime.now() # 记录开始时间
  540. if args.code and args.code_file:
  541. print('[ERROR] --code 与 --code-file 不能同时使用')
  542. return
  543. # 如果指定了代码或代码文件,则自动推断模块名并注入
  544. if args.code_file:
  545. code_path = Path(args.code_file)
  546. try:
  547. code_text = code_path.read_text(encoding='utf-8')
  548. except Exception as e:
  549. print(f'[ERROR] 读取代码文件失败: {e}')
  550. return
  551. inferred_module = code_path.stem.replace('-', '_').replace(' ', '_')
  552. hint = f"# MODULE_NAME_HINT: {inferred_module}"
  553. # 避免重复重复注入
  554. if not code_text.lstrip().startswith('# MODULE_NAME_HINT:'):
  555. code_text = hint + code_text
  556. else:
  557. code_text = args.code
  558. result_state = await transpiler.process_input(code_text, state, system_prompt=system_prompt_text)
  559. if result_state["messages"] and isinstance(result_state["messages"][-1], AIMessage):
  560. print(f"AI: {result_state['messages'][-1].content}")
  561. # 打印详细的 Token 使用量和费用
  562. end_time = datetime.now()
  563. duration = end_time - start_time
  564. total_seconds = int(duration.total_seconds())
  565. hours, remainder = divmod(total_seconds, 3600)
  566. minutes, seconds = divmod(remainder, 60)
  567. cache_hit_tokens = transpiler._usage_totals.get('prompt_cache_hit_tokens', 0)
  568. cache_miss_tokens = transpiler._usage_totals.get('prompt_cache_miss_tokens', 0)
  569. completion_tokens = transpiler._usage_totals.get('completion_tokens', 0)
  570. total_cost = transpiler._usage_totals.get('cost_yuan', 0.0)
  571. total_tool_calls = transpiler.tool_call_index
  572. # 计费规则 (元/百万 tokens)
  573. PRICE_CACHE_HIT = 0.5
  574. PRICE_CACHE_MISS = 4.0
  575. PRICE_OUTPUT = 12.0
  576. cost_cache_hit = (cache_hit_tokens * PRICE_CACHE_HIT) / 1_000_000
  577. cost_cache_miss = (cache_miss_tokens * PRICE_CACHE_MISS) / 1_000_000
  578. cost_completion = (completion_tokens * PRICE_OUTPUT) / 1_000_000
  579. # 费用占比(防止除零)
  580. denom = total_cost if total_cost > 0 else (cost_cache_hit + cost_cache_miss + cost_completion)
  581. if denom == 0:
  582. pct_cache_hit = pct_cache_miss = pct_completion = 0.0
  583. else:
  584. pct_cache_hit = cost_cache_hit / denom * 100
  585. pct_cache_miss = cost_cache_miss / denom * 100
  586. pct_completion = cost_completion / denom * 100
  587. print(f"\n{'='*30}\nUsage & Stats Summary:\n")
  588. print(f" - Cache Input: {cache_hit_tokens / 1000:.1f}k tokens ({cost_cache_hit:.3f} 元, {pct_cache_hit:.2f}%)")
  589. print(f" - Fresh Input: {cache_miss_tokens / 1000:.1f}k tokens ({cost_cache_miss:.3f} 元, {pct_cache_miss:.2f}%)")
  590. print(f" - Output: {completion_tokens / 1000:.1f}k tokens ({cost_completion:.3f} 元, {pct_completion:.2f}%)")
  591. print(f" - Tool Calls: {total_tool_calls}")
  592. print(f" - Total Time: {hours}h {minutes}m {seconds}s")
  593. print(f"\n{'='*30}")
  594. print(f"Total Cost: {total_cost:.3f} 元\n{'='*30}")
  595. # 退出前关闭资源
  596. await transpiler.close()
  597. return
  598. # 若不 exit, 继续进入交互
  599. state = result_state
  600. # 打印详细的 Token 使用量和费用
  601. end_time = datetime.now()
  602. duration = end_time - start_time
  603. total_seconds = int(duration.total_seconds())
  604. hours, remainder = divmod(total_seconds, 3600)
  605. minutes, seconds = divmod(remainder, 60)
  606. cache_hit_tokens = transpiler._usage_totals.get('prompt_cache_hit_tokens', 0)
  607. cache_miss_tokens = transpiler._usage_totals.get('prompt_cache_miss_tokens', 0)
  608. completion_tokens = transpiler._usage_totals.get('completion_tokens', 0)
  609. total_cost = transpiler._usage_totals.get('cost_yuan', 0.0)
  610. total_tool_calls = transpiler.tool_call_index
  611. # 计费规则 (元/百万 tokens)
  612. PRICE_CACHE_HIT = 0.5
  613. PRICE_CACHE_MISS = 4.0
  614. PRICE_OUTPUT = 12.0
  615. cost_cache_hit = (cache_hit_tokens * PRICE_CACHE_HIT) / 1_000_000
  616. cost_cache_miss = (cache_miss_tokens * PRICE_CACHE_MISS) / 1_000_000
  617. cost_completion = (completion_tokens * PRICE_OUTPUT) / 1_000_000
  618. print(f"\n{'='*30}\nUsage & Stats Summary:\n")
  619. print(f" - Cache Input: {cache_hit_tokens / 1000:.2f}k tokens ({cost_cache_hit:.6f} 元)")
  620. print(f" - Fresh Input: {cache_miss_tokens / 1000:.2f}k tokens ({cost_cache_miss:.6f} 元)")
  621. print(f" - Output: {completion_tokens / 1000:.2f}k tokens ({cost_completion:.6f} 元)")
  622. print(f" - Tool Calls: {total_tool_calls}")
  623. print(f" - Total Time: {hours}h {minutes}m {seconds}s")
  624. print(f"\n{'='*30}")
  625. print(f"Total Cost: {total_cost:.6f} 元\n{'='*30}")
  626. # 退出前关闭资源
  627. await transpiler.close()
  628. return
  629. # 若不 exit, 继续进入交互
  630. state = result_state
  631. # 交互循环
  632. while True:
  633. try:
  634. user_input = input("你: ").strip()
  635. except EOFError:
  636. break
  637. if user_input.lower() in ["退出", "exit", "quit"]:
  638. print("再见!")
  639. break
  640. if not user_input:
  641. continue
  642. state = await transpiler.process_input(user_input, state, system_prompt=system_prompt_text)
  643. if state["messages"] and isinstance(state["messages"][-1], AIMessage):
  644. print(f"AI: {state['messages'][-1].content}")
  645. if __name__ == "__main__":
  646. asyncio.run(main())
  647. # 简化接口函数
  648. def initialize_agent():
  649. """初始化agent(单例模式会自动处理)"""
  650. return ChatAgent()
  651. def get_agent_response(agent_instance, user_input):
  652. return agent_instance.process_message(user_input)
  653. def get_agent_status(agent_instance):
  654. return {
  655. 'initialized_time': agent_instance.initialized_time,
  656. 'total_conversations': len(agent_instance.conversation_history) // 2,
  657. 'model_loaded': agent_instance.model_loaded
  658. }