From 57519b9006604a1569e1787bc066bbef305bb71e Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Thu, 25 May 2023 19:50:10 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=9A=E5=9C=BA=E6=99=AF=E5=AF=B9=E8=AF=9D?= =?UTF-8?q?=E6=9E=B6=E6=9E=84=E4=B8=80=E6=9C=9F0525?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pilot/scene/chat_db/chat.py | 20 +++----------------- pilot/scene/chat_db/out_parser.py | 26 ++++++++++++++------------ 2 files changed, 17 insertions(+), 29 deletions(-) diff --git a/pilot/scene/chat_db/chat.py b/pilot/scene/chat_db/chat.py index c86810f2b..72ad64508 100644 --- a/pilot/scene/chat_db/chat.py +++ b/pilot/scene/chat_db/chat.py @@ -83,7 +83,7 @@ class ChatWithDb(BaseChat): "stop": self.prompt_template.sep, } logger.info(f"Requert: \n{payload}") - + ai_response_text = "" try: ### 走非流式的模型服务接口 @@ -94,21 +94,7 @@ class ChatWithDb(BaseChat): result = self.database.run(self.db_connect, prompt_define_response.sql) - # # TODO - TEST - # resp_test = { - # "SQL": "select * from users", - # "thoughts": { - # "text": "thought", - # "reasoning": "reasoning", - # "plan": "- short bulleted\n- list that conveys\n- long-term plan", - # "criticism": "constructive self-criticism", - # "speak": "thoughts summary to say to user" - # } - # } - # - # sql_action = SqlAction(**resp_test) - # self.current_message.add_ai_message(json.dumps(sql_action._asdict())) - # result = self.database.run(self.db_connect, sql_action.SQL) + if hasattr(prompt_define_response, 'thoughts'): if prompt_define_response.thoughts.get("speak"): self.current_message.add_view_message( @@ -126,7 +112,7 @@ class ChatWithDb(BaseChat): except Exception as e: print(traceback.format_exc()) logger.error("model response parase faild!" + str(e)) - self.current_message.add_view_message(f"ERROR:{str(e)}!{ai_response_text}") + self.current_message.add_view_message(f"""ERROR!{str(e)}\n {ai_response_text} """) ### 对话记录存储 self.memory.append(self.current_message) diff --git a/pilot/scene/chat_db/out_parser.py b/pilot/scene/chat_db/out_parser.py index 6c511508d..1d2597f57 100644 --- a/pilot/scene/chat_db/out_parser.py +++ b/pilot/scene/chat_db/out_parser.py @@ -1,26 +1,20 @@ import json +import re from abc import ABC, abstractmethod from typing import ( - Any, Dict, - Generic, - List, - NamedTuple, - Optional, - Sequence, - TypeVar, - Union, + NamedTuple ) import pandas as pd - +from pilot.utils import build_logger from pilot.out_parser.base import BaseOutputParser, T - +from pilot.configs.model_config import LOGDIR class SqlAction(NamedTuple): sql: str thoughts: Dict - +logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") class DbChatOutputParser(BaseOutputParser): def __init__(self, sep:str, is_stream_out: bool): @@ -43,9 +37,17 @@ class DbChatOutputParser(BaseOutputParser): if cleaned_output.endswith("```"): cleaned_output = cleaned_output[: -len("```")] cleaned_output = cleaned_output.strip() + if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"): + logger.info("illegal json processing") + json_pattern = r'{(.+?)}' + m = re.search(json_pattern, cleaned_output) + if m: + cleaned_output = m.group(0) + else: + raise ValueError("model server out not fllow the prompt!") + response = json.loads(cleaned_output) sql, thoughts = response["sql"], response["thoughts"] - return SqlAction(sql, thoughts) def parse_view_response(self, speak, data) -> str: