mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-08 03:44:14 +00:00
多场景对话架构一期0525
This commit is contained in:
parent
6ca0385358
commit
57519b9006
@ -83,7 +83,7 @@ class ChatWithDb(BaseChat):
|
||||
"stop": self.prompt_template.sep,
|
||||
}
|
||||
logger.info(f"Requert: \n{payload}")
|
||||
|
||||
ai_response_text = ""
|
||||
try:
|
||||
### 走非流式的模型服务接口
|
||||
|
||||
@ -94,21 +94,7 @@ class ChatWithDb(BaseChat):
|
||||
|
||||
result = self.database.run(self.db_connect, prompt_define_response.sql)
|
||||
|
||||
# # TODO - TEST
|
||||
# resp_test = {
|
||||
# "SQL": "select * from users",
|
||||
# "thoughts": {
|
||||
# "text": "thought",
|
||||
# "reasoning": "reasoning",
|
||||
# "plan": "- short bulleted\n- list that conveys\n- long-term plan",
|
||||
# "criticism": "constructive self-criticism",
|
||||
# "speak": "thoughts summary to say to user"
|
||||
# }
|
||||
# }
|
||||
#
|
||||
# sql_action = SqlAction(**resp_test)
|
||||
# self.current_message.add_ai_message(json.dumps(sql_action._asdict()))
|
||||
# result = self.database.run(self.db_connect, sql_action.SQL)
|
||||
|
||||
if hasattr(prompt_define_response, 'thoughts'):
|
||||
if prompt_define_response.thoughts.get("speak"):
|
||||
self.current_message.add_view_message(
|
||||
@ -126,7 +112,7 @@ class ChatWithDb(BaseChat):
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error("model response parase faild!" + str(e))
|
||||
self.current_message.add_view_message(f"ERROR:{str(e)}!{ai_response_text}")
|
||||
self.current_message.add_view_message(f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """)
|
||||
### 对话记录存储
|
||||
self.memory.append(self.current_message)
|
||||
|
||||
|
@ -1,26 +1,20 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
NamedTuple
|
||||
)
|
||||
import pandas as pd
|
||||
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
class SqlAction(NamedTuple):
|
||||
sql: str
|
||||
thoughts: Dict
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
class DbChatOutputParser(BaseOutputParser):
|
||||
|
||||
def __init__(self, sep:str, is_stream_out: bool):
|
||||
@ -43,9 +37,17 @@ class DbChatOutputParser(BaseOutputParser):
|
||||
if cleaned_output.endswith("```"):
|
||||
cleaned_output = cleaned_output[: -len("```")]
|
||||
cleaned_output = cleaned_output.strip()
|
||||
if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
|
||||
logger.info("illegal json processing")
|
||||
json_pattern = r'{(.+?)}'
|
||||
m = re.search(json_pattern, cleaned_output)
|
||||
if m:
|
||||
cleaned_output = m.group(0)
|
||||
else:
|
||||
raise ValueError("model server out not fllow the prompt!")
|
||||
|
||||
response = json.loads(cleaned_output)
|
||||
sql, thoughts = response["sql"], response["thoughts"]
|
||||
|
||||
return SqlAction(sql, thoughts)
|
||||
|
||||
def parse_view_response(self, speak, data) -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user