mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
fix plugin mode bug;Optimize the parsing logic for model response
This commit is contained in:
parent
4e5ce4d98b
commit
18bacbd7f7
@ -15,6 +15,7 @@ import requests
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import PLUGINS_DIR
|
||||
from pilot.logs import logger
|
||||
|
||||
|
||||
@ -82,7 +83,7 @@ def load_native_plugins(cfg: Config):
|
||||
headers={'Authorization': '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'})
|
||||
|
||||
if response.status_code == 200:
|
||||
plugins_path_path = Path(cfg.plugins_dir)
|
||||
plugins_path_path = Path(PLUGINS_DIR)
|
||||
files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
@ -111,7 +112,7 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
|
||||
current_dir = os.getcwd()
|
||||
print(current_dir)
|
||||
# Generic plugins
|
||||
plugins_path_path = Path(cfg.plugins_dir)
|
||||
plugins_path_path = Path(PLUGINS_DIR)
|
||||
|
||||
logger.debug(f"Allowlisted Plugins: {cfg.plugins_allowlist}")
|
||||
logger.debug(f"Denylisted Plugins: {cfg.plugins_denylist}")
|
||||
|
@ -88,7 +88,6 @@ class Config(metaclass=Singleton):
|
||||
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
|
||||
|
||||
### The associated configuration parameters of the plug-in control the loading and use of the plug-in
|
||||
self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
|
||||
self.plugins: List[AutoGPTPluginTemplate] = []
|
||||
self.plugins_openai = []
|
||||
|
||||
|
@ -13,8 +13,18 @@ VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
|
||||
LOGDIR = os.path.join(ROOT_PATH, "logs")
|
||||
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
|
||||
DATA_DIR = os.path.join(PILOT_PATH, "data")
|
||||
|
||||
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
||||
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
|
||||
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
|
||||
|
||||
# 获取当前工作目录
|
||||
current_directory = os.getcwd()
|
||||
print("当前工作目录:", current_directory)
|
||||
|
||||
# 设置当前工作目录
|
||||
new_directory = PILOT_PATH
|
||||
os.chdir(new_directory)
|
||||
print("新的工作目录:", os.getcwd())
|
||||
|
||||
DEVICE = (
|
||||
"cuda"
|
||||
|
@ -9,46 +9,66 @@ from pyecharts import options as opts
|
||||
|
||||
CFG = Config()
|
||||
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# # 创建连接池
|
||||
# engine = create_engine('mysql+pymysql://root:aa123456@localhost:3306/gpt-user')
|
||||
#
|
||||
# # 从连接池中获取连接
|
||||
#
|
||||
#
|
||||
# # 归还连接到连接池中
|
||||
#
|
||||
# # 执行SQL语句并将结果转化为DataFrame
|
||||
# query = "SELECT * FROM users"
|
||||
# df = pd.read_sql(query, engine.connect())
|
||||
# df.style.set_properties(subset=['name'], **{'font-weight': 'bold'})
|
||||
# # 导出为HTML文件
|
||||
# with open('report.html', 'w') as f:
|
||||
# f.write(df.style.render())
|
||||
#
|
||||
# # # 设置中文字体
|
||||
# # font = FontProperties(fname='SimHei.ttf', size=14)
|
||||
# #
|
||||
# # colors = np.random.rand(df.shape[0])
|
||||
# # df.plot.scatter(x='city', y='user_name', c=colors)
|
||||
# # plt.show()
|
||||
#
|
||||
# # 查看DataFrame
|
||||
# print(df.head())
|
||||
#
|
||||
#
|
||||
# # 创建数据
|
||||
# x_data = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
|
||||
# y_data = [820, 932, 901, 934, 1290, 1330, 1320]
|
||||
#
|
||||
# # 生成图表
|
||||
# bar = (
|
||||
# Bar()
|
||||
# .add_xaxis(x_data)
|
||||
# .add_yaxis("销售额", y_data)
|
||||
# .set_global_opts(title_opts=opts.TitleOpts(title="销售额统计"))
|
||||
# )
|
||||
#
|
||||
# # 生成HTML文件
|
||||
# bar.render('report.html')
|
||||
#
|
||||
#
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建连接池
|
||||
engine = create_engine('mysql+pymysql://root:aa123456@localhost:3306/gpt-user')
|
||||
def __extract_json(s):
    """Return the first balanced ``{...}`` object embedded in *s*.

    The scan starts at the first ``{`` and tracks nesting depth until the
    matching ``}`` is reached. Braces that appear inside double-quoted JSON
    strings (including after a backslash-escaped quote) are ignored, so a
    value such as ``"a } b"`` does not end the object early.

    Args:
        s: Text that contains a JSON object, possibly surrounded by prose.

    Returns:
        The substring of *s* from the opening ``{`` through its matching
        ``}``, inclusive.

    Raises:
        ValueError: If *s* contains no ``{`` or the object is never closed.
    """
    start = s.find("{")
    if start == -1:
        raise ValueError("no JSON object found: missing '{'")

    depth = 0  # number of '{' that are still waiting for their '}'
    in_string = False  # True while inside a double-quoted JSON string
    escaped = False  # True when the previous string character was a backslash
    for pos in range(start, len(s)):
        ch = s[pos]
        if in_string:
            # Inside a string only an unescaped quote is significant.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return s[start:pos + 1]

    # Explicit raise instead of ``assert``: asserts are removed under ``-O``.
    raise ValueError("no JSON object found: unbalanced braces")
|
||||
|
||||
# 从连接池中获取连接
|
||||
|
||||
|
||||
# 归还连接到连接池中
|
||||
|
||||
# 执行SQL语句并将结果转化为DataFrame
|
||||
query = "SELECT * FROM users"
|
||||
df = pd.read_sql(query, engine.connect())
|
||||
df.style.set_properties(subset=['name'], **{'font-weight': 'bold'})
|
||||
# 导出为HTML文件
|
||||
with open('report.html', 'w') as f:
|
||||
f.write(df.style.render())
|
||||
|
||||
# # 设置中文字体
|
||||
# font = FontProperties(fname='SimHei.ttf', size=14)
|
||||
#
|
||||
# colors = np.random.rand(df.shape[0])
|
||||
# df.plot.scatter(x='city', y='user_name', c=colors)
|
||||
# plt.show()
|
||||
|
||||
# 查看DataFrame
|
||||
print(df.head())
|
||||
|
||||
|
||||
# 创建数据
|
||||
x_data = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
|
||||
y_data = [820, 932, 901, 934, 1290, 1330, 1320]
|
||||
|
||||
# 生成图表
|
||||
bar = (
|
||||
Bar()
|
||||
.add_xaxis(x_data)
|
||||
.add_yaxis("销售额", y_data)
|
||||
.set_global_opts(title_opts=opts.TitleOpts(title="销售额统计"))
|
||||
)
|
||||
|
||||
# 生成HTML文件
|
||||
bar.render('report.html')
|
||||
ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
|
||||
print(__extract_json(ss))
|
BIN
pilot/fonts/SimHei.ttf
Normal file
BIN
pilot/fonts/SimHei.ttf
Normal file
Binary file not shown.
BIN
pilot/mock_datas/db-gpt-test.db
Normal file
BIN
pilot/mock_datas/db-gpt-test.db
Normal file
Binary file not shown.
@ -95,7 +95,6 @@ class BaseOutputParser(ABC):
|
||||
def parse_model_nostream_resp(self, response, sep: str):
|
||||
text = response.text.strip()
|
||||
text = text.rstrip()
|
||||
text = text.lower()
|
||||
respObj = json.loads(text)
|
||||
|
||||
xx = respObj["response"]
|
||||
@ -111,7 +110,9 @@ class BaseOutputParser(ABC):
|
||||
last_index = i
|
||||
ai_response = tmpResp[last_index]
|
||||
ai_response = ai_response.replace("assistant:", "")
|
||||
ai_response = ai_response.replace("\n", "")
|
||||
ai_response = ai_response.replace("Assistant:", "")
|
||||
ai_response = ai_response.replace("ASSISTANT:", "")
|
||||
ai_response = ai_response.replace("\n", " ")
|
||||
ai_response = ai_response.replace("\_", "_")
|
||||
ai_response = ai_response.replace("\*", "*")
|
||||
print("un_stream ai response:", ai_response)
|
||||
@ -119,6 +120,19 @@ class BaseOutputParser(ABC):
|
||||
else:
|
||||
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
|
||||
|
||||
def __extract_json(self, s):
    """Return the first balanced ``{...}`` object embedded in *s*.

    The scan starts at the first ``{`` and tracks nesting depth until the
    matching ``}`` is reached. Braces that appear inside double-quoted JSON
    strings (including after a backslash-escaped quote) are ignored, so a
    value such as ``"a } b"`` does not end the object early.

    Args:
        s: Model output text that contains a JSON object, possibly
            surrounded by prose.

    Returns:
        The substring of *s* from the opening ``{`` through its matching
        ``}``, inclusive.

    Raises:
        ValueError: If *s* contains no ``{`` or the object is never closed.
    """
    start = s.find("{")
    if start == -1:
        raise ValueError("no JSON object found in model output: missing '{'")

    depth = 0  # number of '{' that are still waiting for their '}'
    in_string = False  # True while inside a double-quoted JSON string
    escaped = False  # True when the previous string character was a backslash
    for pos in range(start, len(s)):
        ch = s[pos]
        if in_string:
            # Inside a string only an unescaped quote is significant.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return s[start:pos + 1]

    # Explicit raise instead of ``assert``: asserts are removed under ``-O``.
    raise ValueError("no JSON object found in model output: unbalanced braces")
|
||||
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
"""
|
||||
parse model out text to prompt define response
|
||||
@ -129,8 +143,8 @@ class BaseOutputParser(ABC):
|
||||
|
||||
"""
|
||||
cleaned_output = model_out_text.rstrip()
|
||||
# if "```json" in cleaned_output:
|
||||
# _, cleaned_output = cleaned_output.split("```json")
|
||||
if "```json" in cleaned_output:
|
||||
_, cleaned_output = cleaned_output.split("```json")
|
||||
# if "```" in cleaned_output:
|
||||
# cleaned_output, _ = cleaned_output.split("```")
|
||||
if cleaned_output.startswith("```json"):
|
||||
@ -142,18 +156,12 @@ class BaseOutputParser(ABC):
|
||||
cleaned_output = cleaned_output.strip()
|
||||
if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
|
||||
logger.info("illegal json processing")
|
||||
json_pattern = r"{(.+?)}"
|
||||
m = re.search(json_pattern, cleaned_output)
|
||||
if m:
|
||||
cleaned_output = m.group(0)
|
||||
else:
|
||||
raise ValueError("model server out not fllow the prompt!")
|
||||
cleaned_output = self.__extract_json(cleaned_output)
|
||||
cleaned_output = (
|
||||
cleaned_output.strip()
|
||||
.replace("\n", "")
|
||||
.replace("\\n", "")
|
||||
.replace("\\", "")
|
||||
.replace("\\", "")
|
||||
.replace("\n", " ")
|
||||
.replace("\\n", " ")
|
||||
.replace("\\", " ")
|
||||
)
|
||||
return cleaned_output
|
||||
|
||||
|
@ -78,11 +78,8 @@ class ChatWithPlugin(BaseChat):
|
||||
super().chat_show()
|
||||
|
||||
def __list_to_prompt_str(self, list: List) -> str:
|
||||
if list:
|
||||
separator = "\n"
|
||||
return separator.join(list)
|
||||
else:
|
||||
return ""
|
||||
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
|
||||
|
||||
|
||||
def generate(self, p) -> str:
|
||||
return super().generate(p)
|
||||
|
@ -14,20 +14,24 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
class PluginAction(NamedTuple):
|
||||
command: Dict
|
||||
speak: str
|
||||
reasoning: str
|
||||
thoughts: str
|
||||
|
||||
|
||||
class PluginChatOutputParser(BaseOutputParser):
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
response = json.loads(super().parse_prompt_response(model_out_text))
|
||||
command, thoughts, speak, reasoning = (
|
||||
clean_json_str = super().parse_prompt_response(model_out_text)
|
||||
print(clean_json_str)
|
||||
try:
|
||||
response = json.loads(clean_json_str)
|
||||
except Exception as e:
|
||||
raise ValueError("model server out not fllow the prompt!")
|
||||
|
||||
command, thoughts, speak = (
|
||||
response["command"],
|
||||
response["thoughts"],
|
||||
response["speak"],
|
||||
response["reasoning"],
|
||||
response["speak"]
|
||||
)
|
||||
return PluginAction(command, speak, reasoning, thoughts)
|
||||
return PluginAction(command, speak, thoughts)
|
||||
|
||||
def parse_view_response(self, speak, data) -> str:
|
||||
### tool out data to table view
|
||||
|
@ -10,7 +10,7 @@ from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
|
||||
|
||||
CFG = Config()
|
||||
|
||||
PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.Play to your strengths as an LLM and pursue simple strategies with no legal complications."""
|
||||
PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
|
||||
|
||||
PROMPT_SUFFIX = """
|
||||
Goals:
|
||||
@ -20,25 +20,22 @@ Goals:
|
||||
|
||||
_DEFAULT_TEMPLATE = """
|
||||
Constraints:
|
||||
Exclusively use the commands listed in double quotes e.g. "command name"
|
||||
Reflect on past decisions and strategies to refine your approach.
|
||||
Constructively self-criticize your big-picture behavior constantly.
|
||||
{constraints}
|
||||
0.Exclusively use the commands listed in double quotes e.g. "command name"
|
||||
{constraints}
|
||||
|
||||
Commands:
|
||||
{commands_infos}
|
||||
{commands_infos}
|
||||
"""
|
||||
|
||||
|
||||
PROMPT_RESPONSE = """You must respond in JSON format as following format:
|
||||
{response}
|
||||
|
||||
PROMPT_RESPONSE = """
|
||||
Please response strictly according to the following json format:
|
||||
{response}
|
||||
Ensure the response is correct json and can be parsed by Python json.loads
|
||||
"""
|
||||
|
||||
RESPONSE_FORMAT = {
|
||||
"thoughts": "thought text",
|
||||
"reasoning": "reasoning",
|
||||
"speak": "thoughts summary to say to user",
|
||||
"command": {"name": "command name", "args": {"arg name": "value"}},
|
||||
}
|
||||
|
BIN
pilot/server/SimHei.ttf
Normal file
BIN
pilot/server/SimHei.ttf
Normal file
Binary file not shown.
@ -4,6 +4,8 @@ import sys
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
|
||||
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
|
||||
print("Setting random seed to 42")
|
||||
random.seed(42)
|
||||
|
File diff suppressed because one or more lines are too long
Binary file not shown.
@ -3,6 +3,7 @@ import os
|
||||
import pytest
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import PLUGINS_DIR
|
||||
from pilot.plugins import (
|
||||
denylist_allowlist_check,
|
||||
inspect_zip_for_modules,
|
||||
|
Loading…
Reference in New Issue
Block a user