mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-08 19:58:27 +00:00
多场景对话架构一期0525
This commit is contained in:
parent
6ca0385358
commit
57519b9006
@ -83,7 +83,7 @@ class ChatWithDb(BaseChat):
|
|||||||
"stop": self.prompt_template.sep,
|
"stop": self.prompt_template.sep,
|
||||||
}
|
}
|
||||||
logger.info(f"Requert: \n{payload}")
|
logger.info(f"Requert: \n{payload}")
|
||||||
|
ai_response_text = ""
|
||||||
try:
|
try:
|
||||||
### 走非流式的模型服务接口
|
### 走非流式的模型服务接口
|
||||||
|
|
||||||
@ -94,21 +94,7 @@ class ChatWithDb(BaseChat):
|
|||||||
|
|
||||||
result = self.database.run(self.db_connect, prompt_define_response.sql)
|
result = self.database.run(self.db_connect, prompt_define_response.sql)
|
||||||
|
|
||||||
# # TODO - TEST
|
|
||||||
# resp_test = {
|
|
||||||
# "SQL": "select * from users",
|
|
||||||
# "thoughts": {
|
|
||||||
# "text": "thought",
|
|
||||||
# "reasoning": "reasoning",
|
|
||||||
# "plan": "- short bulleted\n- list that conveys\n- long-term plan",
|
|
||||||
# "criticism": "constructive self-criticism",
|
|
||||||
# "speak": "thoughts summary to say to user"
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
#
|
|
||||||
# sql_action = SqlAction(**resp_test)
|
|
||||||
# self.current_message.add_ai_message(json.dumps(sql_action._asdict()))
|
|
||||||
# result = self.database.run(self.db_connect, sql_action.SQL)
|
|
||||||
if hasattr(prompt_define_response, 'thoughts'):
|
if hasattr(prompt_define_response, 'thoughts'):
|
||||||
if prompt_define_response.thoughts.get("speak"):
|
if prompt_define_response.thoughts.get("speak"):
|
||||||
self.current_message.add_view_message(
|
self.current_message.add_view_message(
|
||||||
@ -126,7 +112,7 @@ class ChatWithDb(BaseChat):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
logger.error("model response parase faild!" + str(e))
|
logger.error("model response parase faild!" + str(e))
|
||||||
self.current_message.add_view_message(f"ERROR:{str(e)}!{ai_response_text}")
|
self.current_message.add_view_message(f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """)
|
||||||
### 对话记录存储
|
### 对话记录存储
|
||||||
self.memory.append(self.current_message)
|
self.memory.append(self.current_message)
|
||||||
|
|
||||||
|
@ -1,26 +1,20 @@
|
|||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
|
||||||
Dict,
|
Dict,
|
||||||
Generic,
|
NamedTuple
|
||||||
List,
|
|
||||||
NamedTuple,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from pilot.utils import build_logger
|
||||||
from pilot.out_parser.base import BaseOutputParser, T
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
|
||||||
class SqlAction(NamedTuple):
|
class SqlAction(NamedTuple):
|
||||||
sql: str
|
sql: str
|
||||||
thoughts: Dict
|
thoughts: Dict
|
||||||
|
|
||||||
|
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||||
class DbChatOutputParser(BaseOutputParser):
|
class DbChatOutputParser(BaseOutputParser):
|
||||||
|
|
||||||
def __init__(self, sep:str, is_stream_out: bool):
|
def __init__(self, sep:str, is_stream_out: bool):
|
||||||
@ -43,9 +37,17 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
if cleaned_output.endswith("```"):
|
if cleaned_output.endswith("```"):
|
||||||
cleaned_output = cleaned_output[: -len("```")]
|
cleaned_output = cleaned_output[: -len("```")]
|
||||||
cleaned_output = cleaned_output.strip()
|
cleaned_output = cleaned_output.strip()
|
||||||
|
if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
|
||||||
|
logger.info("illegal json processing")
|
||||||
|
json_pattern = r'{(.+?)}'
|
||||||
|
m = re.search(json_pattern, cleaned_output)
|
||||||
|
if m:
|
||||||
|
cleaned_output = m.group(0)
|
||||||
|
else:
|
||||||
|
raise ValueError("model server out not fllow the prompt!")
|
||||||
|
|
||||||
response = json.loads(cleaned_output)
|
response = json.loads(cleaned_output)
|
||||||
sql, thoughts = response["sql"], response["thoughts"]
|
sql, thoughts = response["sql"], response["thoughts"]
|
||||||
|
|
||||||
return SqlAction(sql, thoughts)
|
return SqlAction(sql, thoughts)
|
||||||
|
|
||||||
def parse_view_response(self, speak, data) -> str:
|
def parse_view_response(self, speak, data) -> str:
|
||||||
|
Loading…
Reference in New Issue
Block a user