feat(ChatDB): ChatDB use finetuned model

1. Compatible with community pure-SQL-output models
This commit is contained in:
yhjun1026 2023-11-16 11:50:37 +08:00
parent 1acc8b0134
commit fa06ba5bf6
10 changed files with 178 additions and 80 deletions

View File

@ -372,8 +372,8 @@ class ApiCall:
param["type"] = api_status.name
if api_status.args:
param["sql"] = api_status.args["sql"]
if api_status.err_msg:
param["err_msg"] = api_status.err_msg
# if api_status.err_msg:
# param["err_msg"] = api_status.err_msg
if api_status.api_result:
param["data"] = api_status.api_result

View File

@ -153,11 +153,21 @@ class LLMModelAdaper:
else:
raise ValueError(f"Unknown role: {role}")
can_use_system = ""
if system_messages:
# TODO vicuna 兼容 测试完放弃
if len(system_messages) > 1:
can_use_system = system_messages[0]
conv[-1][0][-1] =system_messages[-1]
elif len(system_messages) == 1:
conv[-1][0][-1] = system_messages[-1]
if isinstance(conv, Conversation):
conv.set_system_message("".join(system_messages))
conv.set_system_message(can_use_system)
else:
conv.update_system_message("".join(system_messages))
conv.update_system_message(can_use_system)
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)

View File

@ -26,6 +26,42 @@ def _build_access_token(api_key: str, secret_key: str) -> str:
return res.json().get("access_token")
def __convert_2_wenxin_messages(messages: "List[ModelMessage]"):
    """Convert generic chat messages to the Wenxin (ERNIE Bot) wire format.

    Wenxin expects a strictly alternating user/assistant message list with
    the system prompt passed out-of-band, so this walks *messages* pairing
    each AI reply with the most recent human turn, then appends one final
    "user" message built from the trailing human input and/or system text.

    Args:
        messages: Ordered conversation history (HUMAN / SYSTEM / AI roles).

    Returns:
        Tuple of ``(wenxin_messages, system_messages)``: the alternating
        ``{"role": ..., "content": ...}`` dicts, plus the raw system message
        contents collected along the way (the caller sends those separately
        in the request payload).
    """
    wenxin_messages = []
    last_usr_message = ""
    system_messages = []
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            last_usr_message = message.content
        elif message.role == ModelMessageRoleType.SYSTEM:
            system_messages.append(message.content)
        elif message.role == ModelMessageRoleType.AI:
            # Each AI reply closes one round: emit the preceding human turn
            # and the assistant answer as a user/assistant pair.
            wenxin_messages.append({"role": "user", "content": last_usr_message})
            wenxin_messages.append({"role": "assistant", "content": message.content})

    # Build the last user message, which carries the current request.
    if system_messages:
        if len(system_messages) > 1:
            # NOTE(review): with more than one system message the trailing
            # human input is dropped and only the last system text is sent
            # -- confirm this is the intended fine-tuned-model behavior.
            end_message = system_messages[-1]
        else:
            last_message = messages[-1]
            if last_message.role == ModelMessageRoleType.HUMAN:
                # Prepend the lone system prompt to the final human turn.
                end_message = system_messages[-1] + "\n" + last_message.content
            else:
                end_message = system_messages[-1]
    else:
        # No system prompt: use the final message content as-is.
        # NOTE(review): raises IndexError if *messages* is empty.
        last_message = messages[-1]
        end_message = last_message.content
    wenxin_messages.append({"role": "user", "content": end_message})
    return wenxin_messages, system_messages
def wenxin_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
@ -40,8 +76,9 @@ def wenxin_generate_stream(
if not model_version:
yield f"Unsupport model version {model_name}"
proxy_api_key = model_params.proxy_api_key
proxy_api_secret = model_params.proxy_api_secret
keys:[] = model_params.proxy_api_key.split(";")
proxy_api_key = keys[0]
proxy_api_secret = keys[1]
access_token = _build_access_token(proxy_api_key, proxy_api_secret)
headers = {"Content-Type": "application/json", "Accept": "application/json"}
@ -51,40 +88,43 @@ def wenxin_generate_stream(
if not access_token:
yield "Failed to get access token. please set the correct api_key and secret key."
history = []
messages: List[ModelMessage] = params["messages"]
# Add history conversation
system = ""
if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
role_define = messages.pop(0)
system = role_define.content
else:
message = messages.pop(0)
if message.role == ModelMessageRoleType.HUMAN:
history.append({"role": "user", "content": message.content})
for message in messages:
if message.role == ModelMessageRoleType.SYSTEM:
history.append({"role": "user", "content": message.content})
# elif message.role == ModelMessageRoleType.HUMAN:
# system = ""
# if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
# role_define = messages.pop(0)
# system = role_define.content
# else:
# message = messages.pop(0)
# if message.role == ModelMessageRoleType.HUMAN:
# history.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.AI:
history.append({"role": "assistant", "content": message.content})
else:
pass
# temp_his = history[::-1]
temp_his = history
last_user_input = None
for m in temp_his:
if m["role"] == "user":
last_user_input = m
break
if last_user_input:
history.remove(last_user_input)
history.append(last_user_input)
# for message in messages:
# if message.role == ModelMessageRoleType.SYSTEM:
# history.append({"role": "user", "content": message.content})
# # elif message.role == ModelMessageRoleType.HUMAN:
# # history.append({"role": "user", "content": message.content})
# elif message.role == ModelMessageRoleType.AI:
# history.append({"role": "assistant", "content": message.content})
# else:
# pass
#
# # temp_his = history[::-1]
# temp_his = history
# last_user_input = None
# for m in temp_his:
# if m["role"] == "user":
# last_user_input = m
# break
#
# if last_user_input:
# history.remove(last_user_input)
# history.append(last_user_input)
#
history, systems = __convert_2_wenxin_messages(messages)
system = ""
if systems and len(systems)>0:
system = systems[0]
payload = {
"messages": history,
"system": system,

View File

@ -8,6 +8,43 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
CHATGLM_DEFAULT_MODEL = "chatglm_pro"
def __convert_2_wenxin_messages(messages: "List[ModelMessage]"):
    """Convert generic chat messages to an alternating user/assistant list.

    NOTE(review): this helper lives in the zhipu proxy but kept the
    "wenxin" name when it was copied over -- the format happens to match
    what ``zhipuai.model_api.sse_invoke`` accepts as ``prompt``.

    Pairs each AI reply with the most recent human turn, then appends one
    final "user" message built from the trailing human input and/or system
    text.

    Args:
        messages: Ordered conversation history (HUMAN / SYSTEM / AI roles).

    Returns:
        Tuple of ``(wenxin_messages, system_messages)``: the alternating
        ``{"role": ..., "content": ...}`` dicts, plus the raw system message
        contents collected along the way.
    """
    wenxin_messages = []
    last_usr_message = ""
    system_messages = []
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            last_usr_message = message.content
        elif message.role == ModelMessageRoleType.SYSTEM:
            system_messages.append(message.content)
        elif message.role == ModelMessageRoleType.AI:
            # Each AI reply closes one round: emit the preceding human turn
            # and the assistant answer as a user/assistant pair.
            wenxin_messages.append({"role": "user", "content": last_usr_message})
            wenxin_messages.append({"role": "assistant", "content": message.content})

    # Build the last user message, which carries the current request.
    if system_messages:
        if len(system_messages) > 1:
            # NOTE(review): with more than one system message the trailing
            # human input is dropped and only the last system text is sent
            # -- confirm this is the intended fine-tuned-model behavior.
            end_message = system_messages[-1]
        else:
            last_message = messages[-1]
            if last_message.role == ModelMessageRoleType.HUMAN:
                # Prepend the lone system prompt to the final human turn.
                end_message = system_messages[-1] + "\n" + last_message.content
            else:
                end_message = system_messages[-1]
    else:
        # No system prompt: use the final message content as-is.
        # NOTE(review): raises IndexError if *messages* is empty.
        last_message = messages[-1]
        end_message = last_message.content
    wenxin_messages.append({"role": "user", "content": end_message})
    return wenxin_messages, system_messages
def zhipu_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
@ -22,40 +59,40 @@ def zhipu_generate_stream(
import zhipuai
zhipuai.api_key = proxy_api_key
history = []
messages: List[ModelMessage] = params["messages"]
# Add history conversation
system = ""
if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
role_define = messages.pop(0)
system = role_define.content
else:
message = messages.pop(0)
if message.role == ModelMessageRoleType.HUMAN:
history.append({"role": "user", "content": message.content})
for message in messages:
if message.role == ModelMessageRoleType.SYSTEM:
history.append({"role": "user", "content": message.content})
# elif message.role == ModelMessageRoleType.HUMAN:
# system = ""
# if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
# role_define = messages.pop(0)
# system = role_define.content
# else:
# message = messages.pop(0)
# if message.role == ModelMessageRoleType.HUMAN:
# history.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.AI:
history.append({"role": "assistant", "content": message.content})
else:
pass
# temp_his = history[::-1]
temp_his = history
last_user_input = None
for m in temp_his:
if m["role"] == "user":
last_user_input = m
break
if last_user_input:
history.remove(last_user_input)
history.append(last_user_input)
# for message in messages:
# if message.role == ModelMessageRoleType.SYSTEM:
# history.append({"role": "user", "content": message.content})
# # elif message.role == ModelMessageRoleType.HUMAN:
# # history.append({"role": "user", "content": message.content})
# elif message.role == ModelMessageRoleType.AI:
# history.append({"role": "assistant", "content": message.content})
# else:
# pass
#
# # temp_his = history[::-1]
# temp_his = history
# last_user_input = None
# for m in temp_his:
# if m["role"] == "user":
# last_user_input = m
# break
#
# if last_user_input:
# history.remove(last_user_input)
# history.append(last_user_input)
history, systems = __convert_2_wenxin_messages(messages)
res = zhipuai.model_api.sse_invoke(
model=proxyllm_backend,
prompt=history,

View File

@ -209,6 +209,8 @@ class BaseChat(ABC):
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
msg =""
view_msg=""
async for output in worker_manager.generate_stream(payload):
### Plug-in research in result generation
msg = self.prompt_template.output_parser.parse_model_stream_resp_ex(

View File

@ -30,13 +30,13 @@ User Questions:
_PROMPT_SCENE_DEFINE_ZH = """你是一个数据分析专家!"""
_DEFAULT_TEMPLATE_ZH = """
请使用上述历史对话中的数据结构信息在满足下面约束条件下通过数据分析回答用户的问题
请使用上述历史对话中生成的数据结构信息在满足下面约束条件下通过duckdb sql数据分析回答用户的问题
约束条件:
1.请充分理解用户的问题使用duckdb sql的方式进行分析 分析内容按下面要求的输出格式返回sql请输出在对应的sql参数中
2.请从如下给出的展示方式种选择最优的一种用以进行数据渲染将类型名称放入返回要求格式的name参数值种如果找不到最合适的则使用'Table'作为展示方式可用数据展示方式如下: {disply_type}
3.SQL中需要使用的表名是: {table_name},请检查你生成的sql不要使用没在数据结构中的列名
4.优先使用数据分析的方式回答如果用户问题不涉及数据分析内容你可以按你的理解进行回答
5.要求的输出格式中<api-call></api-call>部分需要被代码解析执行请确保这部分内容按要求输出
5.要求的输出格式中<api-call></api-call>部分需要被代码解析执行请确保这部分内容按要求输出不要参考历史信息的返回格式请按下面要求返回
请确保你的输出格式如下:
对用户说的想法摘要.<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>

View File

@ -28,23 +28,24 @@ _DEFAULT_TEMPLATE_ZH = """
下面是用户文件{file_name}的一部分数据请学习理解该数据的结构和内容按要求输出解析结果:
{data_example}
分析各列数据的含义和作用并对专业术语进行简单明了的解释, 如果是时间类型请给出时间格式类似:yyyy-MM-dd HH:MM:ss.
请不要修改或者翻译列名确保和给出数据列名一致.
将列名作为key分析解释作为value生成json数组如[\\{{"列名1": "分析解释内容1"\\}},\\{{"列名2":"分析解释2"\\}}]并输出在返回json内容的ColumnAnalysis属性中.
请不要修改或者翻译列名确保和给出数据列名一致
提供一些分析方案思路请一步一步思考
请以JSON格式您的返回格式如下
请以确保只以JSON格式回答格式如下
{response}
"""
_RESPONSE_FORMAT_SIMPLE_ZH = {
"DataAnalysis": "数据内容分析总结",
"ColumnAnalysis": [{"column name1": "字段1介绍专业术语解释(请尽量简单明了)"}],
"AnalysisProgram": ["1.分析方案1图表展示方式1", "2.分析方案2图表展示方式2"],
"ColumnAnalysis": [{"column name": "字段1介绍专业术语解释(请尽量简单明了)"}],
"AnalysisProgram": ["1.分析方案1", "2.分析方案2"],
}
_RESPONSE_FORMAT_SIMPLE_EN = {
"DataAnalysis": "Data content analysis summary",
"ColumnAnalysis": [{"column name1": "Introduction to Column 1 and explanation of professional terms (please try to be as simple and clear as possible)"}],
"AnalysisProgram": ["1. Analysis plan 1, chart display type 1", "2. Analysis plan 2, chart display type 2"],
"ColumnAnalysis": [{"column name": "Introduction to Column 1 and explanation of professional terms (please try to be as simple and clear as possible)"}],
"AnalysisProgram": ["1. Analysis plan ", "2. Analysis plan "],
}
RESPONSE_FORMAT_SIMPLE =(_RESPONSE_FORMAT_SIMPLE_EN if CFG.LANGUAGE == "en" else _RESPONSE_FORMAT_SIMPLE_ZH)

View File

@ -72,6 +72,7 @@ class ChatWithDbAutoExecute(BaseChat):
)
input_values = {
"db_name": self.db_name,
"user_input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,

View File

@ -62,7 +62,7 @@ class DbChatOutputParser(BaseOutputParser):
err_msg = None
try:
if not prompt_response.sql or len(prompt_response.sql) <=0:
return f""" <span style=\"color:red\">[Unresolvable return]</span>\n{speak}"""
return f"""{speak}"""
df = data(prompt_response.sql)
param["type"] = "response_table"

View File

@ -13,7 +13,10 @@ _PROMPT_SCENE_DEFINE_EN = "You are a database expert. "
_PROMPT_SCENE_DEFINE_ZH = "你是一个数据库专家. "
_DEFAULT_TEMPLATE_EN = """
Please create a syntactically correct {dialect} sql based on the user question, use the following tables schema to generate sql:
Please answer the user's question based on the database selected by the user and some of the available table structure definitions of the database.
Database name:
{db_name}
Table structure definition:
{table_info}
Constraint:
@ -31,10 +34,14 @@ Ensure the response is correct json and can be parsed by Python json.loads.
"""
_DEFAULT_TEMPLATE_ZH = """
请根据用户输入问题使用如下的表结构定义创建一个语法正确的 {dialect} sql:
请根据用户选择的数据库和该库的部分可用表结构定义来回答用户问题.
数据库名:
{db_name}
表结构定义:
{table_info}
约束:
1. 请理解用户意图根据用户输入问题使用给出表结构定义创建一个语法正确的 {dialect} sql如果不需要sql则直接回答用户问题
1. 除非用户在问题中指定了他希望获得的具体数据行数否则始终将查询限制为最多 {top_k} 个结果
2. 只能使用表结构信息中提供的表来生成 sql如果无法根据提供的表结构中生成 sql 请说提供的表结构信息不足以生成 sql 查询 禁止随意捏造信息
3. 请注意生成SQL时不要弄错表和列的关系