Mirror of https://github.com/csunny/DB-GPT.git
feat(ChatDB): ChatDB use finetuned model

1. Compatible with community pure-SQL-output models
Parent: 1acc8b0134
Commit: fa06ba5bf6
@@ -372,8 +372,8 @@ class ApiCall:
             param["type"] = api_status.name
             if api_status.args:
                 param["sql"] = api_status.args["sql"]
-            if api_status.err_msg:
-                param["err_msg"] = api_status.err_msg
+            # if api_status.err_msg:
+            #     param["err_msg"] = api_status.err_msg

             if api_status.api_result:
                 param["data"] = api_status.api_result
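This hunk stops copying err_msg into the rendered chart parameter, so a failed SQL run no longer leaks raw error text into the view payload. A minimal sketch of the resulting assembly logic; the ApiStatus stand-in and the build_view_param name are assumptions, only the field names come from the diff:

    from dataclasses import dataclass, field
    from typing import Any, Dict, Optional

    @dataclass
    class ApiStatus:
        # Stand-in for the plugin call status tracked by ApiCall.
        name: str
        args: Dict[str, str] = field(default_factory=dict)
        err_msg: Optional[str] = None
        api_result: Any = None

    def build_view_param(api_status: ApiStatus) -> Dict[str, Any]:
        """Assemble the front-end render parameter, mirroring the diff above."""
        param: Dict[str, Any] = {"type": api_status.name}
        if api_status.args:
            param["sql"] = api_status.args["sql"]
        # err_msg is intentionally no longer forwarded to the view.
        if api_status.api_result:
            param["data"] = api_status.api_result
        return param

    print(build_view_param(ApiStatus(name="response_table", args={"sql": "SELECT 1"})))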
@@ -153,11 +153,21 @@ class LLMModelAdaper:
         else:
             raise ValueError(f"Unknown role: {role}")

+        can_use_system = ""
         if system_messages:
+            # TODO: vicuna compatibility; drop after testing
+            if len(system_messages) > 1:
+                can_use_system = system_messages[0]
+                conv[-1][0][-1] = system_messages[-1]
+            elif len(system_messages) == 1:
+                conv[-1][0][-1] = system_messages[-1]
+
             if isinstance(conv, Conversation):
-                conv.set_system_message("".join(system_messages))
+                conv.set_system_message(can_use_system)
             else:
-                conv.update_system_message("".join(system_messages))
+                conv.update_system_message(can_use_system)

         # Add a blank message for the assistant.
         conv.append_message(conv.roles[1], None)
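Community fine-tunes often accept a single system prompt at most, so the adapter now collapses multiple system messages: the first becomes the system prompt and the last overwrites the final user turn. A self-contained sketch of that policy, with illustrative names that are not the adapter's real API:

    from typing import List, Tuple

    def collapse_system_messages(
        system_messages: List[str], last_user_msg: str
    ) -> Tuple[str, str]:
        """Return (system_prompt, user_msg) for models with one system slot.

        Mirrors the diff: with several system messages the first becomes the
        system prompt and the last replaces the user turn; with exactly one,
        it replaces the user turn and no system prompt is sent.
        """
        if len(system_messages) > 1:
            return system_messages[0], system_messages[-1]
        if len(system_messages) == 1:
            return "", system_messages[-1]
        return "", last_user_msg

    print(collapse_system_messages(["You are a SQL expert.", "Schema: t(a, b)"], "hi"))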
@@ -26,6 +26,42 @@ def _build_access_token(api_key: str, secret_key: str) -> str:
     return res.json().get("access_token")


+def __convert_2_wenxin_messages(messages: List[ModelMessage]):
+    chat_round = 0
+    wenxin_messages = []
+
+    last_usr_message = ""
+    system_messages = []
+
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            last_usr_message = message.content
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            system_messages.append(message.content)
+        elif message.role == ModelMessageRoleType.AI:
+            last_ai_message = message.content
+            wenxin_messages.append({"role": "user", "content": last_usr_message})
+            wenxin_messages.append({"role": "assistant", "content": last_ai_message})
+
+    # build the last user message
+    if len(system_messages) > 0:
+        if len(system_messages) > 1:
+            end_message = system_messages[-1]
+        else:
+            last_message = messages[-1]
+            if last_message.role == ModelMessageRoleType.HUMAN:
+                end_message = system_messages[-1] + "\n" + last_message.content
+            else:
+                end_message = system_messages[-1]
+    else:
+        last_message = messages[-1]
+        end_message = last_message.content
+    wenxin_messages.append({"role": "user", "content": end_message})
+    return wenxin_messages, system_messages
+
+
 def wenxin_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
 ):
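Wenxin's chat endpoint wants strictly alternating user/assistant messages plus a separate system field, which is what this converter produces. A runnable sketch with lightweight stand-ins for ModelMessage and the role constants (the stand-ins are assumptions; the pairing logic follows the diff):

    from dataclasses import dataclass
    from typing import List

    # Minimal stand-ins for pilot's ModelMessage / ModelMessageRoleType.
    @dataclass
    class ModelMessage:
        role: str
        content: str

    HUMAN, SYSTEM, AI = "human", "system", "ai"

    def convert_to_wenxin_messages(messages: List[ModelMessage]):
        wenxin_messages, system_messages = [], []
        last_usr_message = ""
        for m in messages:
            if m.role == HUMAN:
                last_usr_message = m.content
            elif m.role == SYSTEM:
                system_messages.append(m.content)
            elif m.role == AI:
                # Each AI reply closes a round: emit the (user, assistant) pair.
                wenxin_messages.append({"role": "user", "content": last_usr_message})
                wenxin_messages.append({"role": "assistant", "content": m.content})
        # The trailing human turn (optionally prefixed by the system prompt)
        # becomes the final user message.
        end = messages[-1].content
        if len(system_messages) == 1 and messages[-1].role == HUMAN:
            end = system_messages[-1] + "\n" + end
        elif system_messages:
            end = system_messages[-1]
        wenxin_messages.append({"role": "user", "content": end})
        return wenxin_messages, system_messages

    msgs = [ModelMessage(SYSTEM, "Answer with SQL only."), ModelMessage(HUMAN, "Top 10 users?")]
    print(convert_to_wenxin_messages(msgs))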
@@ -40,8 +76,9 @@ def wenxin_generate_stream(
     if not model_version:
         yield f"Unsupported model version {model_name}"

-    proxy_api_key = model_params.proxy_api_key
-    proxy_api_secret = model_params.proxy_api_secret
+    keys: List[str] = model_params.proxy_api_key.split(";")
+    proxy_api_key = keys[0]
+    proxy_api_secret = keys[1]
     access_token = _build_access_token(proxy_api_key, proxy_api_secret)

     headers = {"Content-Type": "application/json", "Accept": "application/json"}
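Wenxin needs an API key and a secret key to mint an access token, so the proxy now packs both into the single proxy_api_key setting, separated by a semicolon. A hedged sketch of a stricter parse than the bare keys[0]/keys[1] indexing; the helper name and error message are illustrative:

    from typing import Tuple

    def split_wenxin_credentials(combined: str) -> Tuple[str, str]:
        """Parse 'API_KEY;SECRET_KEY' into its two parts, failing loudly."""
        parts = [p.strip() for p in combined.split(";")]
        if len(parts) != 2 or not all(parts):
            raise ValueError(
                "proxy_api_key must look like '<api_key>;<secret_key>' for Wenxin"
            )
        return parts[0], parts[1]

    api_key, secret_key = split_wenxin_credentials("my-key;my-secret")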
@@ -51,40 +88,43 @@ def wenxin_generate_stream(
     if not access_token:
         yield "Failed to get access token. Please set the correct api_key and secret key."

-    history = []
-
     messages: List[ModelMessage] = params["messages"]
-    # Add history conversation
-    system = ""
-    if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
-        role_define = messages.pop(0)
-        system = role_define.content
-    else:
-        message = messages.pop(0)
-        if message.role == ModelMessageRoleType.HUMAN:
-            history.append({"role": "user", "content": message.content})
-    for message in messages:
-        if message.role == ModelMessageRoleType.SYSTEM:
-            history.append({"role": "user", "content": message.content})
-        # elif message.role == ModelMessageRoleType.HUMAN:
-        #     history.append({"role": "user", "content": message.content})
-        elif message.role == ModelMessageRoleType.AI:
-            history.append({"role": "assistant", "content": message.content})
-        else:
-            pass
-
-    # temp_his = history[::-1]
-    temp_his = history
-    last_user_input = None
-    for m in temp_his:
-        if m["role"] == "user":
-            last_user_input = m
-            break
-
-    if last_user_input:
-        history.remove(last_user_input)
-        history.append(last_user_input)
-
+    # system = ""
+    # if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
+    #     role_define = messages.pop(0)
+    #     system = role_define.content
+    # else:
+    #     message = messages.pop(0)
+    #     if message.role == ModelMessageRoleType.HUMAN:
+    #         history.append({"role": "user", "content": message.content})
+    # for message in messages:
+    #     if message.role == ModelMessageRoleType.SYSTEM:
+    #         history.append({"role": "user", "content": message.content})
+    #     # elif message.role == ModelMessageRoleType.HUMAN:
+    #     #     history.append({"role": "user", "content": message.content})
+    #     elif message.role == ModelMessageRoleType.AI:
+    #         history.append({"role": "assistant", "content": message.content})
+    #     else:
+    #         pass
+    #
+    # # temp_his = history[::-1]
+    # temp_his = history
+    # last_user_input = None
+    # for m in temp_his:
+    #     if m["role"] == "user":
+    #         last_user_input = m
+    #         break
+    #
+    # if last_user_input:
+    #     history.remove(last_user_input)
+    #     history.append(last_user_input)
+    #
+    history, systems = __convert_2_wenxin_messages(messages)
+    system = ""
+    if systems and len(systems) > 0:
+        system = systems[0]
     payload = {
         "messages": history,
         "system": system,
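With the converter in place, the stream handler reduces to: convert the history, take the first system message, and build the request body. A sketch of that payload assembly; the temperature and stream fields follow the shape of Baidu's public chat API and should be treated as assumptions here:

    import json

    def build_wenxin_payload(history, systems, temperature=0.7, stream=True):
        # First system message wins; Wenxin takes it as a separate "system"
        # field rather than as an entry in the messages list.
        system = systems[0] if systems else ""
        return {
            "messages": history,
            "system": system,
            "temperature": temperature,
            "stream": stream,
        }

    payload = build_wenxin_payload(
        [{"role": "user", "content": "SELECT the top 10 users"}],
        ["Answer with SQL only."],
    )
    print(json.dumps(payload, ensure_ascii=False))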
@@ -8,6 +8,43 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 CHATGLM_DEFAULT_MODEL = "chatglm_pro"


+def __convert_2_wenxin_messages(messages: List[ModelMessage]):
+    chat_round = 0
+    wenxin_messages = []
+
+    last_usr_message = ""
+    system_messages = []
+
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            last_usr_message = message.content
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            system_messages.append(message.content)
+        elif message.role == ModelMessageRoleType.AI:
+            last_ai_message = message.content
+            wenxin_messages.append({"role": "user", "content": last_usr_message})
+            wenxin_messages.append({"role": "assistant", "content": last_ai_message})
+
+    # build the last user message
+    if len(system_messages) > 0:
+        if len(system_messages) > 1:
+            end_message = system_messages[-1]
+        else:
+            last_message = messages[-1]
+            if last_message.role == ModelMessageRoleType.HUMAN:
+                end_message = system_messages[-1] + "\n" + last_message.content
+            else:
+                end_message = system_messages[-1]
+    else:
+        last_message = messages[-1]
+        end_message = last_message.content
+    wenxin_messages.append({"role": "user", "content": end_message})
+    return wenxin_messages, system_messages
+
+
 def zhipu_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
 ):
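Note that the converter is copied verbatim into the zhipu proxy, Wenxin-flavored name included, even though it now feeds ChatGLM. A possible refactor is one shared helper imported by both proxies; the module path and function name below are suggestions, not the repo's actual layout:

    # Hypothetical shared module, e.g. pilot/model/proxy/llms/proxy_message_utils.py
    from typing import Dict, List, Tuple

    def convert_to_paired_messages(messages) -> Tuple[List[Dict[str, str]], List[str]]:
        """Single home for the user/assistant pairing duplicated in wenxin.py and zhipu.py."""
        ...

    # wenxin.py and zhipu.py would then both do:
    # from pilot.model.proxy.llms.proxy_message_utils import convert_to_paired_messages
    # history, systems = convert_to_paired_messages(messages)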
@@ -22,40 +59,40 @@ def zhipu_generate_stream(
     import zhipuai

     zhipuai.api_key = proxy_api_key
-    history = []

     messages: List[ModelMessage] = params["messages"]
-    # Add history conversation
-    system = ""
-    if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
-        role_define = messages.pop(0)
-        system = role_define.content
-    else:
-        message = messages.pop(0)
-        if message.role == ModelMessageRoleType.HUMAN:
-            history.append({"role": "user", "content": message.content})
-    for message in messages:
-        if message.role == ModelMessageRoleType.SYSTEM:
-            history.append({"role": "user", "content": message.content})
-        # elif message.role == ModelMessageRoleType.HUMAN:
-        #     history.append({"role": "user", "content": message.content})
-        elif message.role == ModelMessageRoleType.AI:
-            history.append({"role": "assistant", "content": message.content})
-        else:
-            pass
-
-    # temp_his = history[::-1]
-    temp_his = history
-    last_user_input = None
-    for m in temp_his:
-        if m["role"] == "user":
-            last_user_input = m
-            break
-
-    if last_user_input:
-        history.remove(last_user_input)
-        history.append(last_user_input)
+    # system = ""
+    # if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
+    #     role_define = messages.pop(0)
+    #     system = role_define.content
+    # else:
+    #     message = messages.pop(0)
+    #     if message.role == ModelMessageRoleType.HUMAN:
+    #         history.append({"role": "user", "content": message.content})
+    # for message in messages:
+    #     if message.role == ModelMessageRoleType.SYSTEM:
+    #         history.append({"role": "user", "content": message.content})
+    #     # elif message.role == ModelMessageRoleType.HUMAN:
+    #     #     history.append({"role": "user", "content": message.content})
+    #     elif message.role == ModelMessageRoleType.AI:
+    #         history.append({"role": "assistant", "content": message.content})
+    #     else:
+    #         pass
+    #
+    # # temp_his = history[::-1]
+    # temp_his = history
+    # last_user_input = None
+    # for m in temp_his:
+    #     if m["role"] == "user":
+    #         last_user_input = m
+    #         break
+    #
+    # if last_user_input:
+    #     history.remove(last_user_input)
+    #     history.append(last_user_input)

+    history, systems = __convert_2_wenxin_messages(messages)
     res = zhipuai.model_api.sse_invoke(
         model=proxyllm_backend,
         prompt=history,
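The v1 zhipuai SDK returns a server-sent-event stream from sse_invoke, and the proxy yields progressively longer text as chunks arrive. A hedged consumption sketch; the event names follow the v1 SDK as commonly documented, so verify them against the installed version:

    def stream_zhipu_response(res):
        """Yield progressively longer text as SSE events arrive."""
        text = ""
        for event in res.events():
            if event.event == "add":      # incremental token(s)
                text += event.data
                yield text
            elif event.event == "error":  # surface provider-side failures
                yield text + f"\n[ERROR]: {event.data}"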
@@ -209,6 +209,8 @@ class BaseChat(ABC):
         worker_manager = CFG.SYSTEM_APP.get_component(
             ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
         ).create()
+        msg = ""
+        view_msg = ""
         async for output in worker_manager.generate_stream(payload):
             ### Plug-in research in result generation
             msg = self.prompt_template.output_parser.parse_model_stream_resp_ex(
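Initializing msg and view_msg before the async for loop guards the code that runs after it: if the worker stream yields nothing, both names are still bound instead of raising NameError. A minimal standalone sketch of the pattern with a dummy stream:

    import asyncio

    async def fake_stream():
        # Stand-in for worker_manager.generate_stream(payload); may yield nothing.
        if False:
            yield ""

    async def consume():
        msg = ""        # always bound, even for an empty stream
        view_msg = ""
        async for output in fake_stream():
            msg = output
            view_msg = output
        return msg, view_msg

    print(asyncio.run(consume()))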
@@ -30,13 +30,13 @@ User Questions:

 _PROMPT_SCENE_DEFINE_ZH = """你是一个数据分析专家!"""
 _DEFAULT_TEMPLATE_ZH = """
-请使用上述历史对话中的数据结构信息,在满足下面约束条件下通过数据分析回答用户的问题。
+请使用上述历史对话中生成的数据结构信息,在满足下面约束条件下通过duckdb sql数据分析回答用户的问题。
 约束条件:
 1.请充分理解用户的问题,使用duckdb sql的方式进行分析, 分析内容按下面要求的输出格式返回,sql请输出在对应的sql参数中
 2.请从如下给出的展示方式种选择最优的一种用以进行数据渲染,将类型名称放入返回要求格式的name参数值种,如果找不到最合适的则使用'Table'作为展示方式,可用数据展示方式如下: {disply_type}
 3.SQL中需要使用的表名是: {table_name},请检查你生成的sql,不要使用没在数据结构中的列名,。
 4.优先使用数据分析的方式回答,如果用户问题不涉及数据分析内容,你可以按你的理解进行回答
-5.要求的输出格式中<api-call></api-call>部分需要被代码解析执行,请确保这部分内容按要求输出
+5.要求的输出格式中<api-call></api-call>部分需要被代码解析执行,请确保这部分内容按要求输出,不要参考历史信息的返回格式,请按下面要求返回
 请确保你的输出格式如下:
 对用户说的想法摘要.<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
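The template pins replies to a fixed shape: a short summary followed by one <api-call> block whose <name> selects the chart type and whose <sql> carries the duckdb query. A hedged sketch of extracting those fields with a regex; the real parser in ApiCall may work differently:

    import re
    from typing import Optional, Tuple

    API_CALL_RE = re.compile(
        r"<api-call><name>(?P<name>.*?)</name>"
        r"<args><sql>(?P<sql>.*?)</sql></args></api-call>",
        re.DOTALL,
    )

    def extract_api_call(reply: str) -> Optional[Tuple[str, str]]:
        """Return (display_type, sql) if the reply contains an api-call block."""
        m = API_CALL_RE.search(reply)
        return (m.group("name"), m.group("sql")) if m else None

    reply = "Here is the trend.<api-call><name>LineChart</name><args><sql>SELECT 1</sql></args></api-call>"
    print(extract_api_call(reply))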
@@ -28,23 +28,24 @@ _DEFAULT_TEMPLATE_ZH = """
 下面是用户文件{file_name}的一部分数据,请学习理解该数据的结构和内容,按要求输出解析结果:
 {data_example}
 分析各列数据的含义和作用,并对专业术语进行简单明了的解释, 如果是时间类型请给出时间格式类似:yyyy-MM-dd HH:MM:ss.
-请不要修改或者翻译列名,确保和给出数据列名一致.
-将列名作为key,分析解释作为value,生成json数组如[\\{{"列名1": "分析解释内容1"\\}},\\{{"列名2":"分析解释2"\\}}],并输出在返回json内容的ColumnAnalysis属性中.
+请不要修改或者翻译列名,确保和给出数据列名一致

 提供一些分析方案思路,请一步一步思考。

-请以JSON格式返回您的答案,返回格式如下:
+请以确保只以JSON格式回答,格式如下:
 {response}
 """

 _RESPONSE_FORMAT_SIMPLE_ZH = {
     "DataAnalysis": "数据内容分析总结",
-    "ColumnAnalysis": [{"column name1": "字段1介绍,专业术语解释(请尽量简单明了)"}],
-    "AnalysisProgram": ["1.分析方案1,图表展示方式1", "2.分析方案2,图表展示方式2"],
+    "ColumnAnalysis": [{"column name": "字段1介绍,专业术语解释(请尽量简单明了)"}],
+    "AnalysisProgram": ["1.分析方案1", "2.分析方案2"],
 }
 _RESPONSE_FORMAT_SIMPLE_EN = {
     "DataAnalysis": "Data content analysis summary",
-    "ColumnAnalysis": [{"column name1": "Introduction to Column 1 and explanation of professional terms (please try to be as simple and clear as possible)"}],
-    "AnalysisProgram": ["1. Analysis plan 1, chart display type 1", "2. Analysis plan 2, chart display type 2"],
+    "ColumnAnalysis": [{"column name": "Introduction to Column 1 and explanation of professional terms (please try to be as simple and clear as possible)"}],
+    "AnalysisProgram": ["1. Analysis plan ", "2. Analysis plan "],
 }

 RESPONSE_FORMAT_SIMPLE = (_RESPONSE_FORMAT_SIMPLE_EN if CFG.LANGUAGE == "en" else _RESPONSE_FORMAT_SIMPLE_ZH)
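The selected format dict is typically serialized and substituted for the {response} placeholder so the model sees a concrete JSON skeleton to fill in. A sketch of that wiring under the assumption that plain str.format does the substitution:

    import json

    RESPONSE_FORMAT_SIMPLE = {
        "DataAnalysis": "Data content analysis summary",
        "ColumnAnalysis": [{"column name": "column meaning, terms explained simply"}],
        "AnalysisProgram": ["1. Analysis plan ", "2. Analysis plan "],
    }

    # The prompt template ends with a "answer only in JSON" instruction plus {response}.
    template = "Please make sure to answer only in JSON, in this format:\n{response}"
    prompt = template.format(response=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=2))
    print(prompt)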
@@ -72,6 +72,7 @@ class ChatWithDbAutoExecute(BaseChat):
         )

         input_values = {
+            "db_name": self.db_name,
             "user_input": self.current_user_input,
             "top_k": str(self.top_k),
             "dialect": self.database.dialect,
@@ -62,7 +62,7 @@ class DbChatOutputParser(BaseOutputParser):
         err_msg = None
         try:
             if not prompt_response.sql or len(prompt_response.sql) <= 0:
-                return f""" <span style=\"color:red\">[Unresolvable return]</span>\n{speak}"""
+                return f"""{speak}"""

             df = data(prompt_response.sql)
             param["type"] = "response_table"
@@ -13,7 +13,10 @@ _PROMPT_SCENE_DEFINE_EN = "You are a database expert. "
 _PROMPT_SCENE_DEFINE_ZH = "你是一个数据库专家. "

 _DEFAULT_TEMPLATE_EN = """
-Please create a syntactically correct {dialect} sql based on the user question, use the following tables schema to generate sql:
+Please answer the user's question based on the database selected by the user and some of the available table structure definitions of the database.
+Database name:
+{db_name}
+Table structure definition:
 {table_info}

 Constraint:
@@ -31,10 +34,14 @@ Ensure the response is correct json and can be parsed by Python json.loads.
 """

 _DEFAULT_TEMPLATE_ZH = """
-请根据用户输入问题,使用如下的表结构定义创建一个语法正确的 {dialect} sql:
+请根据用户选择的数据库和该库的部分可用表结构定义来回答用户问题.
+数据库名:
+{db_name}
+表结构定义:
 {table_info}

 约束:
+1. 请理解用户意图根据用户输入问题,使用给出表结构定义创建一个语法正确的 {dialect} sql,如果不需要sql,则直接回答用户问题。
 1. 除非用户在问题中指定了他希望获得的具体数据行数,否则始终将查询限制为最多 {top_k} 个结果。
 2. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
 3. 请注意生成SQL时不要弄错表和列的关系
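With db_name added to both templates and to input_values, the rendered prompt now names the target database explicitly. A quick sketch of filling the placeholders; pilot's PromptTemplate handles the real substitution:

    # Hypothetical rendering of the new English template.
    TEMPLATE_EN = (
        "Please answer the user's question based on the database selected by the user "
        "and some of the available table structure definitions of the database.\n"
        "Database name:\n{db_name}\n"
        "Table structure definition:\n{table_info}\n"
    )

    prompt = TEMPLATE_EN.format(
        db_name="dbgpt_test",
        table_info="users(id INT, name TEXT, created_at DATETIME)",
    )
    print(prompt)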