From fe662cec5e22d38c40aaa51ff6e31e8d37565688 Mon Sep 17 00:00:00 2001
From: "tuyang.yhj"
Date: Thu, 29 Jun 2023 17:24:18 +0800
Subject: [PATCH] WEB API independent

---
 pilot/openapi/api_v1/api_v1.py                 |  5 +++--
 pilot/scene/base_chat.py                       |  4 ++--
 pilot/scene/chat_db/auto_execute/chat.py       |  1 +
 pilot/scene/chat_db/auto_execute/example.py    | 10 ++++------
 pilot/scene/chat_db/auto_execute/out_parser.py | 18 ++++++++++++------
 pilot/scene/chat_execution/chat.py             |  1 +
 pilot/scene/chat_normal/out_parser.py          |  1 +
 7 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index 740af4f05..17d9b88fe 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -252,7 +252,6 @@ async def stream_generator(chat):
         for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
-                chat.current_message.add_ai_message(msg)
                 msg = msg.replace("\n", "\\n")
                 yield f"data:{msg}\n\n"
                 await asyncio.sleep(0.1)
@@ -260,11 +259,13 @@ async def stream_generator(chat):
         for chunk in model_response:
             if chunk:
                 msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
-                chat.current_message.add_ai_message(msg)
                 msg = msg.replace("\n", "\\n")
                 yield f"data:{msg}\n\n"
                 await asyncio.sleep(0.1)
+
+    chat.current_message.add_ai_message(msg)
+    chat.current_message.add_view_message(msg)
     chat.memory.append(chat.current_message)
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 376d99ef8..6c005dbf3 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -148,8 +148,8 @@ class BaseChat(ABC):
             self.current_message.add_view_message(
                 f"""ERROR!{str(e)}\n {ai_response_text} """
             )
-        ### store current conversation
-        self.memory.append(self.current_message)
+            ### store current conversation
+            self.memory.append(self.current_message)
 
     def nostream_call(self):
         payload = self.__call_base()
diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py
index 613c660c7..5b5998024 100644
--- a/pilot/scene/chat_db/auto_execute/chat.py
+++ b/pilot/scene/chat_db/auto_execute/chat.py
@@ -54,4 +54,5 @@ class ChatWithDbAutoExecute(BaseChat):
         return input_values
 
     def do_action(self, prompt_response):
+        print(f"do_action:{prompt_response}")
         return self.database.run(self.db_connect, prompt_response.sql)
diff --git a/pilot/scene/chat_db/auto_execute/example.py b/pilot/scene/chat_db/auto_execute/example.py
index a2a01b44d..b4c248d65 100644
--- a/pilot/scene/chat_db/auto_execute/example.py
+++ b/pilot/scene/chat_db/auto_execute/example.py
@@ -4,14 +4,13 @@ from pilot.common.schema import ExampleType
 EXAMPLES = [
     {
         "messages": [
-            {"type": "human", "data": {"content": "查询xxx", "example": True}},
+            {"type": "human", "data": {"content": "查询用户test1所在的城市", "example": True}},
             {
                 "type": "ai",
                 "data": {
                     "content": """{
                         \"thoughts\": \"thought text\",
-                        \"speak\": \"thoughts summary to say to user\",
-                        \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
+                        \"sql\": \"SELECT city FROM users where user_name='test1'\",
 }""",
                     "example": True,
                 },
@@ -20,14 +19,13 @@ EXAMPLES = [
     {
         "messages": [
-            {"type": "human", "data": {"content": "查询xxx", "example": True}},
+            {"type": "human", "data": {"content": "查询成都的用户的订单信息", "example": True}},
             {
                 "type": "ai",
                 "data": {
                     "content": """{
                         \"thoughts\": \"thought text\",
-                        \"speak\": \"thoughts summary to say to user\",
-                        \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
+                        \"sql\": \"SELECT b.* FROM users a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
 }""",
                     "example": True,
                 },
diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py
index 94cc6ea9e..eaa45498c 100644
--- a/pilot/scene/chat_db/auto_execute/out_parser.py
+++ b/pilot/scene/chat_db/auto_execute/out_parser.py
@@ -6,8 +6,9 @@
 import pandas as pd
 from pilot.utils import build_logger
 from pilot.out_parser.base import BaseOutputParser, T
 from pilot.configs.model_config import LOGDIR
+from pilot.configs.config import Config
-
+CFG = Config()
 class SqlAction(NamedTuple):
     sql: str
     thoughts: Dict
@@ -32,11 +33,16 @@ class DbChatOutputParser(BaseOutputParser):
         if len(data) <= 1:
             data.insert(0, ["result"])
         df = pd.DataFrame(data[1:], columns=data[0])
-        table_style = """"""
-        html_table = df.to_html(index=False, escape=False)
-        html = f"{table_style}{html_table}"
+        if CFG.NEW_SERVER_MODE:
+            html = df.to_html(index=False, escape=False, sparsify=False)
+            html = ''.join(html.split())
+        else:
+            table_style = """"""
+            html_table = df.to_html(index=False, escape=False)
+            html = f"{table_style}{html_table}"
+
         view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
         return view_text
diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py
index 97646c299..1c3735ac3 100644
--- a/pilot/scene/chat_execution/chat.py
+++ b/pilot/scene/chat_execution/chat.py
@@ -63,6 +63,7 @@ class ChatWithPlugin(BaseChat):
         return input_values
 
     def do_action(self, prompt_response):
+        print(f"do_action:{prompt_response}")
         ## plugin command run
         return execute_command(
             str(prompt_response.command.get("name")),
diff --git a/pilot/scene/chat_normal/out_parser.py b/pilot/scene/chat_normal/out_parser.py
index e5edc9b20..112974176 100644
--- a/pilot/scene/chat_normal/out_parser.py
+++ b/pilot/scene/chat_normal/out_parser.py
@@ -12,6 +12,7 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 
 
 class NormalChatOutputParser(BaseOutputParser):
+
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text
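
Note on the out_parser.py hunk above: the short sketch below illustrates the new CFG.NEW_SERVER_MODE rendering branch. It is not part of the patch; NEW_SERVER_MODE is shown as a plain boolean instead of the project's Config object, and the sample rows and the speak text are invented for illustration.

import pandas as pd

# Stand-in for CFG.NEW_SERVER_MODE (read from pilot.configs.config.Config in the real code).
NEW_SERVER_MODE = True

# Hypothetical query result in the [header, *rows] shape parse_view_response receives.
data = [["city"], ["成都"], ["北京"]]
if len(data) <= 1:
    data.insert(0, ["result"])
df = pd.DataFrame(data[1:], columns=data[0])

if NEW_SERVER_MODE:
    # New path: render a bare table and remove every whitespace character so the
    # markup survives the later "\n" -> " " replacement.
    html = df.to_html(index=False, escape=False, sparsify=False)
    html = "".join(html.split())
else:
    # Old path: plain to_html output (the patch wraps this in an inline style block).
    html = df.to_html(index=False, escape=False)

speak = "Cities of the matched users"  # invented summary text
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
print(view_text)

Collapsing all whitespace in the NEW_SERVER_MODE path is what keeps the table markup intact after view_text replaces each newline with a space.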