多场景对话架构一期0525

This commit is contained in:
yhjun1026 2023-05-25 19:50:10 +08:00
parent 6ca0385358
commit 57519b9006
2 changed files with 17 additions and 29 deletions

View File

@ -83,7 +83,7 @@ class ChatWithDb(BaseChat):
"stop": self.prompt_template.sep,
}
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
### 走非流式的模型服务接口
@ -94,21 +94,7 @@ class ChatWithDb(BaseChat):
result = self.database.run(self.db_connect, prompt_define_response.sql)
# # TODO - TEST
# resp_test = {
# "SQL": "select * from users",
# "thoughts": {
# "text": "thought",
# "reasoning": "reasoning",
# "plan": "- short bulleted\n- list that conveys\n- long-term plan",
# "criticism": "constructive self-criticism",
# "speak": "thoughts summary to say to user"
# }
# }
#
# sql_action = SqlAction(**resp_test)
# self.current_message.add_ai_message(json.dumps(sql_action._asdict()))
# result = self.database.run(self.db_connect, sql_action.SQL)
if hasattr(prompt_define_response, 'thoughts'):
if prompt_define_response.thoughts.get("speak"):
self.current_message.add_view_message(
@ -126,7 +112,7 @@ class ChatWithDb(BaseChat):
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild" + str(e))
self.current_message.add_view_message(f"ERROR:{str(e)}!{ai_response_text}")
self.current_message.add_view_message(f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """)
### 对话记录存储
self.memory.append(self.current_message)

View File

@ -1,26 +1,20 @@
import json
import re
from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
NamedTuple
)
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
class SqlAction(NamedTuple):
sql: str
thoughts: Dict
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class DbChatOutputParser(BaseOutputParser):
def __init__(self, sep:str, is_stream_out: bool):
@ -43,9 +37,17 @@ class DbChatOutputParser(BaseOutputParser):
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
logger.info("illegal json processing")
json_pattern = r'{(.+?)}'
m = re.search(json_pattern, cleaned_output)
if m:
cleaned_output = m.group(0)
else:
raise ValueError("model server out not fllow the prompt!")
response = json.loads(cleaned_output)
sql, thoughts = response["sql"], response["thoughts"]
return SqlAction(sql, thoughts)
def parse_view_response(self, speak, data) -> str: