WEB API independent

This commit is contained in:
tuyang.yhj 2023-06-29 17:24:18 +08:00
parent 7208dd6c88
commit fe662cec5e
7 changed files with 24 additions and 16 deletions

View File

@ -252,7 +252,6 @@ async def stream_generator(chat):
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk: if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg)
msg = msg.replace("\n", "\\n") msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n" yield f"data:{msg}\n\n"
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
@ -260,11 +259,13 @@ async def stream_generator(chat):
for chunk in model_response: for chunk in model_response:
if chunk: if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg)
msg = msg.replace("\n", "\\n") msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n" yield f"data:{msg}\n\n"
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
chat.current_message.add_ai_message(msg)
chat.current_message.add_view_message(msg)
chat.memory.append(chat.current_message) chat.memory.append(chat.current_message)

View File

@ -148,8 +148,8 @@ class BaseChat(ABC):
self.current_message.add_view_message( self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """ f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
) )
### store current conversation ### store current conversation
self.memory.append(self.current_message) self.memory.append(self.current_message)
def nostream_call(self): def nostream_call(self):
payload = self.__call_base() payload = self.__call_base()

View File

@ -54,4 +54,5 @@ class ChatWithDbAutoExecute(BaseChat):
return input_values return input_values
def do_action(self, prompt_response): def do_action(self, prompt_response):
print(f"do_action:{prompt_response}")
return self.database.run(self.db_connect, prompt_response.sql) return self.database.run(self.db_connect, prompt_response.sql)

View File

@ -4,14 +4,13 @@ from pilot.common.schema import ExampleType
EXAMPLES = [ EXAMPLES = [
{ {
"messages": [ "messages": [
{"type": "human", "data": {"content": "查询xxx", "example": True}}, {"type": "human", "data": {"content": "查询用户test1所在的城市", "example": True}},
{ {
"type": "ai", "type": "ai",
"data": { "data": {
"content": """{ "content": """{
\"thoughts\": \"thought text\", \"thoughts\": \"thought text\",
\"speak\": \"thoughts summary to say to user\", \"sql\": \"SELECT city FROM users where user_name='test1'\",
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
}""", }""",
"example": True, "example": True,
}, },
@ -20,14 +19,13 @@ EXAMPLES = [
}, },
{ {
"messages": [ "messages": [
{"type": "human", "data": {"content": "查询xxx", "example": True}}, {"type": "human", "data": {"content": "查询成都的用户的订单信息", "example": True}},
{ {
"type": "ai", "type": "ai",
"data": { "data": {
"content": """{ "content": """{
\"thoughts\": \"thought text\", \"thoughts\": \"thought text\",
\"speak\": \"thoughts summary to say to user\", \"sql\": \"SELECT b.* FROM users a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
}""", }""",
"example": True, "example": True,
}, },

View File

@ -6,8 +6,9 @@ import pandas as pd
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config
CFG = Config()
class SqlAction(NamedTuple): class SqlAction(NamedTuple):
sql: str sql: str
thoughts: Dict thoughts: Dict
@ -32,11 +33,16 @@ class DbChatOutputParser(BaseOutputParser):
if len(data) <= 1: if len(data) <= 1:
data.insert(0, ["result"]) data.insert(0, ["result"])
df = pd.DataFrame(data[1:], columns=data[0]) df = pd.DataFrame(data[1:], columns=data[0])
table_style = """<style> if CFG.NEW_SERVER_MODE:
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444} html = df.to_html(index=False, escape=False, sparsify=False)
</style>""" html = ''.join(html.split())
html_table = df.to_html(index=False, escape=False) else:
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>" table_style = """<style>
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}
</style>"""
html_table = df.to_html(index=False, escape=False)
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ") view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
return view_text return view_text

View File

@ -63,6 +63,7 @@ class ChatWithPlugin(BaseChat):
return input_values return input_values
def do_action(self, prompt_response): def do_action(self, prompt_response):
print(f"do_action:{prompt_response}")
## plugin command run ## plugin command run
return execute_command( return execute_command(
str(prompt_response.command.get("name")), str(prompt_response.command.get("name")),

View File

@ -12,6 +12,7 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser): class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T: def parse_prompt_response(self, model_out_text) -> T:
return model_out_text return model_out_text