Implemented a new multi-scenario dialogue architecture

This commit is contained in:
yhjun1026 2023-05-31 15:59:50 +08:00
parent 973bcce03c
commit 06bc4452d4
35 changed files with 905 additions and 743 deletions

View File

@ -277,6 +277,7 @@ class Database:
def run(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
print("sql run:" + command)
cursor = session.execute(text(command))
if cursor.returns_rows:
if fetch == "all":

View File

@ -105,18 +105,14 @@ class Conversation:
}
def gen_sqlgen_conversation(dbname):
from pilot.connections.mysql import MySQLOperator
mo = MySQLOperator(**(DB_SETTINGS))
message = ""
schemas = mo.get_schema(dbname)
for s in schemas:
message += s["schema_info"] + ";"
return f"Database {dbname} Schema information as follows: {message}\n"
conv_default = Conversation(
system = None,
roles=("human", "ai"),
messages= (),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_one_shot = Conversation(
system="A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. "
@ -261,7 +257,7 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回
# question:
# {question}
# """
default_conversation = conv_one_shot
default_conversation = conv_default
chat_mode_title = {
@ -290,7 +286,3 @@ conv_templates = {
"vicuna_v1": conv_vicuna_v1,
"auto_dbgpt_one_shot": auto_dbgpt_one_shot,
}
if __name__ == "__main__":
message = gen_sqlgen_conversation("dbgpt")
print(message)

View File

@ -21,22 +21,46 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
}
messages = prompt.split(stop)
# Add history conversation
for i in range(1, len(messages) - 2, 2):
history.append(
{"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
)
history.append(
{
"role": "system",
"content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
}
)
for message in messages:
if len(message) <= 0:
continue
if "human:" in message:
history.append(
{"role": "user", "content": message.split("human:")[1]},
)
elif "system:" in message:
history.append(
{
"role": "system",
"content": message.split("system:")[1],
}
)
elif "ai:" in message:
history.append(
{
"role": "ai",
"content": message.split("ai:")[1],
}
)
else:
history.append(
{
"role": "system",
"content": message,
}
)
# 把最后一个用户的信息移动到末尾
temp_his = history[::-1]
last_user_input = None
for m in temp_his:
if m["role"] == "user":
last_user_input = m
if last_user_input:
history.remove(last_user_input)
history.append(last_user_input)
# Add user query
query = messages[-2].split(ROLE_USER + ":")[1]
history.append({"role": "user", "content": query})
payloads = {
"model": "gpt-3.5-turbo", # just for test, remove this later
"messages": history,

View File

@ -36,7 +36,7 @@ class BaseOutputParser(ABC):
self.sep = sep
self.is_stream_out = is_stream_out
def __post_process_code(code):
def __post_process_code(self, code):
sep = "\n```"
if sep in code:
blocks = code.split(sep)
@ -92,7 +92,7 @@ class BaseOutputParser(ABC):
ai_response = ai_response.replace("\n", "")
ai_response = ai_response.replace("\_", "_")
ai_response = ai_response.replace("\*", "*")
print("un_stream clear response:{}", ai_response)
print("un_stream ai response:", ai_response)
return ai_response
else:
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
@ -140,6 +140,7 @@ class BaseOutputParser(ABC):
cleaned_output = m.group(0)
else:
raise ValueError("model server out not fllow the prompt!")
cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '')
return cleaned_output
def parse_view_response(self, ai_text) -> str:

View File

@ -31,15 +31,15 @@ DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
class PromptTemplate(BaseModel, ABC):
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
template_scene: str
template_scene: Optional[str]
template_define: str
template_define: Optional[str]
"""this template define"""
template: str
template: Optional[str]
"""The prompt template."""
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
response_format: str
response_format: Optional[str]
"""default use stream out"""
stream_out: bool = True
""""""
@ -57,52 +57,12 @@ class PromptTemplate(BaseModel, ABC):
"""Return the prompt type key."""
return "prompt"
def _generate_command_string(self, command: Dict[str, Any]) -> str:
"""
Generate a formatted string representation of a command.
Args:
command (dict): A dictionary containing command information.
Returns:
str: The formatted command string.
"""
args_string = ", ".join(
f'"{key}": "{value}"' for key, value in command["args"].items()
)
return f'{command["label"]}: "{command["name"]}", args: {args_string}'
def _generate_numbered_list(self, items: List[Any], item_type="list") -> str:
"""
Generate a numbered list from given items based on the item_type.
Args:
items (list): A list of items to be numbered.
item_type (str, optional): The type of items in the list.
Defaults to 'list'.
Returns:
str: The formatted numbered list.
"""
if item_type == "command":
command_strings = []
if self.command_registry:
command_strings += [
str(item)
for item in self.command_registry.commands.values()
if item.enabled
]
# terminate command is added manually
command_strings += [self._generate_command_string(item) for item in items]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
else:
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
kwargs["response"] = json.dumps(self.response_format, indent=4)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
if self.template:
if self.response_format:
kwargs["response"] = json.dumps(self.response_format, indent=4)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
def add_goals(self, goal: str) -> None:
self.goals.append(goal)

View File

@ -2,8 +2,10 @@ from enum import Enum
class ChatScene(Enum):
ChatWithDb = "chat_with_db"
ChatWithDbExecute = "chat_with_db_execute"
ChatWithDbQA = "chat_with_db_qa"
ChatExecution = "chat_execution"
ChatKnowledge = "chat_default_knowledge"
ChatNewKnowledge = "chat_new_knowledge"
ChatUrlKnowledge = "chat_url_knowledge"
ChatNormal = "chat_normal"

View File

@ -56,7 +56,7 @@ class BaseChat(ABC):
arbitrary_types_allowed = True
def __init__(self, chat_mode, chat_session_id, current_user_input):
def __init__(self,temperature, max_new_tokens, chat_mode, chat_session_id, current_user_input):
self.chat_session_id = chat_session_id
self.chat_mode = chat_mode
self.current_user_input: str = current_user_input
@ -64,12 +64,12 @@ class BaseChat(ABC):
### TODO
self.memory = FileHistoryMemory(chat_session_id)
### load prompt template
self.prompt_template: PromptTemplate = CFG.prompt_templates[
self.chat_mode.value
]
self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
self.history_message: List[OnceConversation] = []
self.current_message: OnceConversation = OnceConversation()
self.current_tokens_used: int = 0
self.temperature = temperature
self.max_new_tokens = max_new_tokens
### load chat_session_id's chat historys
self._load_history(self.chat_session_id)
@ -92,15 +92,17 @@ class BaseChat(ABC):
pass
def __call_base(self):
input_values = self.generate_input_values()
input_values = self.generate_input_values()
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
self.current_message.start_date = datetime.datetime.now()
# TODO
self.current_message.tokens = 0
current_prompt = None
current_prompt = self.prompt_template.format(**input_values)
if self.prompt_template.template:
current_prompt = self.prompt_template.format(**input_values)
### 构建当前对话, 是否安第一次对话prompt构造 是否考虑切换库
if self.history_message:
@ -108,8 +110,8 @@ class BaseChat(ABC):
logger.info(
f"There are already {len(self.history_message)} rounds of conversations!"
)
self.current_message.add_system_message(current_prompt)
if current_prompt:
self.current_message.add_system_message(current_prompt)
payload = {
"model": self.llm_model,
@ -118,7 +120,6 @@ class BaseChat(ABC):
"max_new_tokens": int(self.max_new_tokens),
"stop": self.prompt_template.sep,
}
logger.info(f"Requert: \n{payload}")
return payload
def stream_call(self):
@ -127,30 +128,18 @@ class BaseChat(ABC):
ai_response_text = ""
try:
show_info = ""
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
timeout=120,
)
# response = requests.post(
# urljoin(CFG.MODEL_SERVER, "generate_stream"),
# headers=headers,
# json=payload,
# timeout=120,
# )
#
# ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
# for resp_text_trunck in ai_response_text:
# show_info = resp_text_trunck
# yield resp_text_trunck + "▌"
#
#### MOCK TEST
def mock_stream_out():
for i in range(1, 11):
time.sleep(0.5)
yield f"Message:{i}"
for msg in mock_stream_out():
show_info = msg
yield msg + ""
for resp_text_trunck in ai_response_text:
show_info = resp_text_trunck
yield resp_text_trunck + ""
self.current_message.add_ai_message(show_info)
@ -186,13 +175,13 @@ class BaseChat(ABC):
result = self.do_with_prompt_response(prompt_define_response)
if hasattr(prompt_define_response, "thoughts"):
if prompt_define_response.thoughts.get("speak"):
if hasattr(prompt_define_response.thoughts, "speak"):
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(
prompt_define_response.thoughts.get("speak"), result
)
)
elif prompt_define_response.thoughts.get("reasoning"):
elif hasattr(prompt_define_response.thoughts, "reasoning"):
self.current_message.add_view_message(
self.prompt_template.output_parser.parse_view_response(
prompt_define_response.thoughts.get("reasoning"), result
@ -223,15 +212,18 @@ class BaseChat(ABC):
def call(self):
if self.prompt_template.stream_out:
yield self.stream_call()
yield self.stream_call()
else:
return self.nostream_call()
def generate_llm_text(self) -> str:
text = self.prompt_template.template_define + self.prompt_template.sep
### 线处理历史信息
text = ""
if self.prompt_template.template_define:
text = self.prompt_template.template_define + self.prompt_template.sep
### 处理历史信息
if len(self.history_message) > self.chat_retention_rounds:
### 使用历史信息的第一轮和最后一轮数据合并成历史对话记录, 做上下文提示时,用户展示消息需要过滤掉
### 使用历史信息的第一轮和最后n轮数据合并成历史对话记录, 做上下文提示时,用户展示消息需要过滤掉
for first_message in self.history_message[0].messages:
if not isinstance(first_message, ViewMessage):
text += (
@ -262,8 +254,8 @@ class BaseChat(ABC):
+ message.content
+ self.prompt_template.sep
)
### current conversation
for now_message in self.current_message.messages:
text += (
now_message.type + ":" + now_message.content + self.prompt_template.sep
@ -298,34 +290,3 @@ class BaseChat(ABC):
"""
pass
if __name__ == "__main__":
#
# def call_back(t, m):
# print(t)
# print(m)
#
# def my_fn(call_fn, xx):
# call_fn(1, xx)
#
#
# my_fn(call_back, "1231")
def my_generator():
while True:
value = yield
print('Received value:', value)
if value == 'stop':
return
# 创建生成器对象
gen = my_generator()
# 启动生成器
next(gen)
# 发送数据到生成器
gen.send('Hello')
gen.send('World')
gen.send('stop')

View File

@ -0,0 +1,57 @@
import json
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.auto_execute.prompt import prompt
CFG = Config()
class ChatWithDbAutoExecute(BaseChat):
chat_scene: str = ChatScene.ChatWithDbExecute.value
"""Number of results to return from the query"""
def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatWithDbExecute,
chat_session_id=chat_session_id,
current_user_input=user_input)
if not db_name:
raise ValueError(f"{ChatScene.ChatWithDbExecute.value} mode should chose db!")
self.db_name = db_name
self.database = CFG.local_db
# 准备DB信息(拿到指定库的链接)
self.db_connect = self.database.get_session(self.db_name)
self.top_k: int = 5
def generate_input_values(self):
input_values = {
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect)
}
return input_values
def do_with_prompt_response(self, prompt_response):
return self.database.run(self.db_connect, prompt_response.sql)
if __name__ == "__main__":
ss = "{\n \"thoughts\": \"to get the user's city, we need to join the users table with the tran_order table using the user_name column. we also need to filter the results to only show orders for user test1.\",\n \"sql\": \"select o.order_id, o.product_name, u.city from tran_order o join users u on o.user_name = u.user_name where o.user_name = 'test1' limit 5\"\n}"
ss.strip().replace('\n', '').replace('\\n', '').replace('', '').replace(' ', '').replace('\\', '').replace('\\', '')
print(ss)
json.loads(ss)

View File

@ -22,7 +22,9 @@ class DbChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text):
response = json.loads(super().parse_prompt_response(model_out_text))
clean_str = super().parse_prompt_response(model_out_text);
print("clean prompt response:", clean_str)
response = json.loads(clean_str)
sql, thoughts = response["sql"], response["thoughts"]
return SqlAction(sql, thoughts)

View File

@ -1,36 +1,30 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.out_parser import DbChatOutputParser, SqlAction
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle
CFG = Config()
PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
PROMPT_SUFFIX = """Only use the following tables:
{table_info}
Question: {input}
"""
_DEFAULT_TEMPLATE = """
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
If the given table is beyond the scope of use, do not use it forcibly.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
"""
_mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".
PROMPT_SUFFIX = """Only use the following tables:
{table_info}
Question: {input}
"""
@ -49,17 +43,16 @@ RESPONSE_FORMAT = {
}
RESPONSE_FORMAT_SIMPLE = {
"thoughts": "thoughts summary to say to user",
"thoughts": "thoughts summary to say to user",
"sql": "SQL Query to run",
}
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
chat_db_prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDb.value,
prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbExecute.value,
input_variables=["input", "table_info", "dialect", "top_k", "response"],
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4),
template_define=PROMPT_SCENE_DEFINE,
@ -69,5 +62,5 @@ chat_db_prompt = PromptTemplate(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})
CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt})

View File

@ -1,240 +0,0 @@
import requests
import datetime
import threading
import json
import traceback
from urllib.parse import urljoin
from sqlalchemy import (
MetaData,
Table,
create_engine,
inspect,
select,
text,
)
from typing import Any, Iterable, List, Optional
from pilot.scene.base_message import (
BaseMessage,
SystemMessage,
HumanMessage,
AIMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.scene.chat_db.out_parser import SqlAction
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.utils import (
build_logger,
server_error_msg,
)
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.scene.chat_db.prompt import chat_db_prompt
from pilot.out_parser.base import BaseOutputParser
from pilot.scene.chat_db.out_parser import DbChatOutputParser
CFG = Config()
class ChatWithDb(BaseChat):
chat_scene: str = ChatScene.ChatWithDb.value
"""Number of results to return from the query"""
def __init__(self, chat_session_id, db_name, user_input):
""" """
super().__init__(chat_mode=ChatScene.ChatWithDb, chat_session_id=chat_session_id, current_user_input=user_input)
if not db_name:
raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!")
self.db_name = db_name
self.database = CFG.local_db
# 准备DB信息(拿到指定库的链接)
self.db_connect = self.database.get_session(self.db_name)
self.top_k: int = 5
def generate_input_values(self):
input_values = {
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect)
}
return input_values
def do_with_prompt_response(self, prompt_response):
return self.database.run(self.db_connect, prompt_response.sql)
# def call(self) -> str:
# input_values = {
# "input": self.current_user_input,
# "top_k": str(self.top_k),
# "dialect": self.database.dialect,
# "table_info": self.database.table_simple_info(self.db_connect),
# # "stop": self.sep_style,
# }
#
# ### Chat sequence advance
# self.current_message.chat_order = len(self.history_message) + 1
# self.current_message.add_user_message(self.current_user_input)
# self.current_message.start_date = datetime.datetime.now()
# # TODO
# self.current_message.tokens = 0
#
# current_prompt = self.prompt_template.format(**input_values)
#
# ### 构建当前对话, 是否安第一次对话prompt构造 是否考虑切换库
# if self.history_message:
# ## TODO 带历史对话记录的场景需要确定切换库后怎么处理
# logger.info(
# f"There are already {len(self.history_message)} rounds of conversations!"
# )
#
# self.current_message.add_system_message(current_prompt)
#
# payload = {
# "model": self.llm_model,
# "prompt": self.generate_llm_text(),
# "temperature": float(self.temperature),
# "max_new_tokens": int(self.max_new_tokens),
# "stop": self.prompt_template.sep,
# }
# logger.info(f"Requert: \n{payload}")
# ai_response_text = ""
# try:
# ### 走非流式的模型服务接口
#
# response = requests.post(
# urljoin(CFG.MODEL_SERVER, "generate"),
# headers=headers,
# json=payload,
# timeout=120,
# )
# ai_response_text = (
# self.prompt_template.output_parser.parse_model_server_out(response)
# )
# self.current_message.add_ai_message(ai_response_text)
# prompt_define_response = (
# self.prompt_template.output_parser.parse_prompt_response(
# ai_response_text
# )
# )
#
# result = self.database.run(self.db_connect, prompt_define_response.sql)
#
# if hasattr(prompt_define_response, "thoughts"):
# if prompt_define_response.thoughts.get("speak"):
# self.current_message.add_view_message(
# self.prompt_template.output_parser.parse_view_response(
# prompt_define_response.thoughts.get("speak"), result
# )
# )
# elif prompt_define_response.thoughts.get("reasoning"):
# self.current_message.add_view_message(
# self.prompt_template.output_parser.parse_view_response(
# prompt_define_response.thoughts.get("reasoning"), result
# )
# )
# else:
# self.current_message.add_view_message(
# self.prompt_template.output_parser.parse_view_response(
# prompt_define_response.thoughts, result
# )
# )
# else:
# self.current_message.add_view_message(
# self.prompt_template.output_parser.parse_view_response(
# prompt_define_response, result
# )
# )
#
# except Exception as e:
# print(traceback.format_exc())
# logger.error("model response parase faild" + str(e))
# self.current_message.add_view_message(
# f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
# )
# ### 对话记录存储
# self.memory.append(self.current_message)
def chat_show(self):
ret = []
# 单论对话只能有一次User 记录 和一次 AI 记录
# TODO 推理过程前端展示。。。
for message in self.current_message.messages:
if isinstance(message, HumanMessage):
ret[-1][-2] = message.content
# 是否展示推理过程
if isinstance(message, ViewMessage):
ret[-1][-1] = message.content
return ret
@property
def chat_type(self) -> str:
return ChatScene.ChatExecution.value
if __name__ == "__main__":
# chat: ChatWithDb = ChatWithDb("chat123", "gpt-user", "查询用户信息")
#
# chat.call()
#
# resp = chat.chat_show()
#
# print(vars(resp))
# memory = FileHistoryMemory("test123")
# once1 = OnceConversation()
# once1.add_user_message("问题测试")
# once1.add_system_message("prompt1")
# once1.add_system_message("prompt2")
# once1.chat_order = 1
# once1.set_start_time(datetime.datetime.now())
# memory.append(once1)
#
# once = OnceConversation()
# once.add_user_message("问题测试2")
# once.add_system_message("prompt3")
# once.add_system_message("prompt4")
# once.chat_order = 2
# once.set_start_time(datetime.datetime.now())
# memory.append(once)
db: Database = CFG.local_db
db_connect = db.get_session("gpt-user")
data = db.run(db_connect, "select * from users")
print(generate_htm_table(data))
#
# print(db.run(db_connect, "select * from users"))
#
# #
# # def print_numbers():
# # db_connect1 = db.get_session("dbgpt-test")
# # cursor1 = db_connect1.execute(text("select * from test_name"))
# # if cursor1.returns_rows:
# # result1 = cursor1.fetchall()
# # print( result1)
# #
# #
# # # 创建线程
# # t = threading.Thread(target=print_numbers)
# # # 启动线程
# # t.start()
#
# print(db.run(db_connect, "select * from tran_order"))
#
# print(db.run(db_connect, "select count(*) as aa from tran_order"))
#
# print(db.table_simple_info(db_connect))
# my_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]
# index = 3
# last_three_elements = my_list[-index:]
# print(last_three_elements)

View File

@ -0,0 +1,56 @@
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.professional_qa.prompt import prompt
CFG = Config()
class ChatWithDbQA(BaseChat):
chat_scene: str = ChatScene.ChatWithDbQA.value
"""Number of results to return from the query"""
def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatWithDbQA,
chat_session_id=chat_session_id,
current_user_input=user_input)
self.db_name = db_name
if db_name:
self.database = CFG.local_db
# 准备DB信息(拿到指定库的链接)
self.db_connect = self.database.get_session(self.db_name)
self.top_k: int = 5
def generate_input_values(self):
table_info = ""
dialect = "mysql"
if self.db_name:
table_info = self.database.table_simple_info(self.db_connect)
dialect = self.database.dialect
input_values = {
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": dialect,
"table_info": table_info
}
return input_values
def do_with_prompt_response(self, prompt_response):
if self.auto_execute:
return self.database.run(self.db_connect, prompt_response.sql)
else:
return prompt_response

View File

@ -0,0 +1,22 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text
def parse_view_response(self, ai_text) -> str:
return super().parse_view_response(ai_text)
def get_format_instructions(self) -> str:
pass

View File

@ -0,0 +1,48 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.professional_qa.out_parser import NormalChatOutputParser
from pilot.common.schema import SeparatorStyle
CFG = Config()
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. """
PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info:
{table_info}
Question: {input}
"""
_DEFAULT_TEMPLATE = """
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbQA.value,
input_variables=["input", "table_info", "dialect", "top_k"],
response_format=None,
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX ,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -10,8 +10,7 @@ from pilot.scene.base import ChatScene
from pilot.configs.config import Config
from pilot.commands.command import execute_command
from pilot.prompts.generator import PluginPromptGenerator
from pilot.scene.chat_execution.prompt import chat_plugin_prompt
from pilot.scene.chat_execution.prompt import prompt
CFG = Config()
@ -20,8 +19,12 @@ class ChatWithPlugin(BaseChat):
plugins_prompt_generator:PluginPromptGenerator
select_plugin: str = None
def __init__(self, chat_session_id, user_input, plugin_selector:str=None):
super().__init__(chat_mode=ChatScene.ChatExecution, chat_session_id=chat_session_id, current_user_input=user_input)
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, plugin_selector:str=None):
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatExecution,
chat_session_id=chat_session_id,
current_user_input=user_input)
self.plugins_prompt_generator = PluginPromptGenerator()
self.plugins_prompt_generator.command_registry = CFG.command_registry
# 加载插件中可用命令

View File

@ -20,8 +20,8 @@ class PluginChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
response = json.loads(super().parse_prompt_response(model_out_text))
sql, thoughts = response["command"], response["thoughts"]
return PluginAction(sql, thoughts)
command, thoughts = response["command"], response["thoughts"]
return PluginAction(command, thoughts)
def parse_view_response(self, ai_text) -> str:
return super().parse_view_response(ai_text)

View File

@ -1,4 +1,5 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
@ -50,7 +51,7 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output
PROMPT_NEED_NEED_STREAM_OUT = False
chat_plugin_prompt = PromptTemplate(
prompt = PromptTemplate(
template_scene=ChatScene.ChatExecution.value,
input_variables=["input", "constraints", "commands_infos", "response"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
@ -62,4 +63,4 @@ chat_plugin_prompt = PromptTemplate(
),
)
CFG.prompt_templates.update({chat_plugin_prompt.template_scene: chat_plugin_prompt})
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -1,65 +0,0 @@
import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
CFG = Config()
PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.Play to your strengths as an LLM and pursue simple strategies with no legal complications."""
PROMPT_SUFFIX = """
Goals:
{input}
"""
_DEFAULT_TEMPLATE = """
Constraints:
Exclusively use the commands listed in double quotes e.g. "command name"
Reflect on past decisions and strategies to refine your approach.
Constructively self-criticize your big-picture behavior constantly.
{constraints}
Commands:
{commands_infos}
"""
PROMPT_RESPONSE = """You must respond in JSON format as following format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads
"""
RESPONSE_FORMAT = {
"thoughts": {
"text": "thought",
"reasoning": "reasoning",
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
"criticism": "constructive self-criticism",
"speak": "thoughts summary to say to user",
},
"command": {"name": "command name", "args": {"arg name": "value"}},
}
PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output
PROMPT_NEED_NEED_STREAM_OUT = False
chat_plugin_prompt = PromptTemplate(
template_scene=ChatScene.ChatExecution.value,
input_variables=["input", "table_info", "dialect", "top_k", "response"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=PluginChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({chat_plugin_prompt.template_scene: chat_plugin_prompt})

View File

@ -1,8 +1,14 @@
from pilot.scene.base_chat import BaseChat
from pilot.singleton import Singleton
from pilot.scene.chat_db.chat import ChatWithDb
import inspect
import importlib
from pilot.scene.chat_execution.chat import ChatWithPlugin
from pilot.scene.chat_normal.chat import ChatNormal
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
class ChatFactory(metaclass=Singleton):
@staticmethod
@ -13,5 +19,5 @@ class ChatFactory(metaclass=Singleton):
if cls.chat_scene == chat_mode:
implementation = cls(**kwargs)
if implementation == None:
raise Exception("Invalid implementation name:" + chat_mode)
raise Exception(f"Invalid implementation name:{chat_mode}")
return implementation

View File

@ -0,0 +1,69 @@
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
VECTOR_SEARCH_TOP_K,
)
from pilot.scene.chat_normal.prompt import prompt
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
CFG = Config()
class ChatNewKnowledge (BaseChat):
chat_scene: str = ChatScene.ChatNewKnowledge.value
"""Number of results to return from the query"""
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, knowledge_name):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatNewKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input)
self.knowledge_name = knowledge_name
vector_store_config = {
"vector_store_name": knowledge_name,
"text_field": "content",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
self.knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
def generate_input_values(self):
docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
docs = docs[:2000]
input_values = {
"context": docs,
"question": self.current_user_input
}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.ChatNewKnowledge.value

View File

@ -0,0 +1,22 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text
def parse_view_response(self, ai_text) -> str:
return super().parse_view_response(ai_text)
def get_format_instructions(self) -> str:
pass

View File

@ -0,0 +1,43 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
CFG = Config()
_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造
已知内容:
{context}
问题:
{question}
"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ChatNewKnowledge.value,
input_variables=["context", "question"],
response_format=None,
template_define=None,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -0,0 +1,66 @@
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
VECTOR_SEARCH_TOP_K,
)
from pilot.scene.chat_normal.prompt import prompt
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
CFG = Config()
class ChatDefaultKnowledge (BaseChat):
chat_scene: str = ChatScene.ChatKnowledge.value
"""Number of results to return from the query"""
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input)
vector_store_config = {
"vector_store_name": "default",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
self.knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
def generate_input_values(self):
docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
docs = docs[:2000]
input_values = {
"context": docs,
"question": self.current_user_input
}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.ChatKnowledge.value

View File

@ -0,0 +1,22 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text
def parse_view_response(self, ai_text) -> str:
return super().parse_view_response(ai_text)
def get_format_instructions(self) -> str:
pass

View File

@ -0,0 +1,43 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
CFG = Config()
_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造
已知内容:
{context}
问题:
{question}
"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ChatKnowledge.value,
input_variables=["context", "question"],
response_format=None,
template_define=None,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -0,0 +1,71 @@
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
VECTOR_SEARCH_TOP_K,
)
from pilot.scene.chat_normal.prompt import prompt
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
CFG = Config()
class ChatUrlKnowledge (BaseChat):
chat_scene: str = ChatScene.ChatUrlKnowledge.value
"""Number of results to return from the query"""
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, url):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatUrlKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input)
self.url = url
vector_store_config = {
"vector_store_name": url,
"text_field": "content",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
self.knowledge_embedding_client = KnowledgeEmbedding(
file_path=url,
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
# url soruce in vector
self.knowledge_embedding_client.knowledge_embedding()
def generate_input_values(self):
docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
docs = docs[:2000]
input_values = {
"context": docs,
"question": self.current_user_input
}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.ChatUrlKnowledge.value

View File

@ -0,0 +1,22 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text
def parse_view_response(self, ai_text) -> str:
return super().parse_view_response(ai_text)
def get_format_instructions(self) -> str:
pass

View File

@ -0,0 +1,43 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
CFG = Config()
_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造
已知内容:
{context}
问题:
{question}
"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ChatUrlKnowledge.value,
input_variables=["context", "question"],
response_format=None,
template_define=None,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -0,0 +1,43 @@
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.scene.chat_normal.prompt import prompt
CFG = Config()
class ChatNormal(BaseChat):
chat_scene: str = ChatScene.ChatNormal.value
"""Number of results to return from the query"""
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatNormal,
chat_session_id=chat_session_id,
current_user_input=user_input)
def generate_input_values(self):
input_values = {
"input": self.current_user_input
}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.ChatNormal.value

View File

@ -0,0 +1,22 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text
def parse_view_response(self, ai_text) -> str:
return super().parse_view_response(ai_text)
def get_format_instructions(self) -> str:
pass

View File

@ -1,31 +1,33 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
def stream_write_and_read(lst):
# 对lst使用yield from进行可迭代对象的扁平化
yield from lst
while True:
val = yield
lst.append(val)
CFG = Config()
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ChatNormal.value,
input_variables=["input"],
response_format=None,
template_define=None,
template=None,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)
if __name__ == "__main__":
# 创建一个空列表
my_list = []
CFG.prompt_templates.update({prompt.template_scene: prompt})
# 使用生成器写入数据
stream_writer = stream_write_and_read(my_list)
next(stream_writer)
stream_writer.send(10)
print(1)
stream_writer.send(20)
print(2)
stream_writer.send(30)
print(3)
# 使用生成器读取数据
stream_reader = stream_write_and_read(my_list)
next(stream_reader)
print(stream_reader.send(None))
print(stream_reader.send(None))
print(stream_reader.send(None))

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import traceback
import argparse
import datetime
import json
@ -9,7 +9,6 @@ import shutil
import sys
import time
import uuid
from urllib.parse import urljoin
import gradio as gr
import requests
@ -216,19 +215,26 @@ def post_process_code(code):
return code
def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene:
def get_chat_mode(selected, param=None) -> ChatScene:
if chat_mode_title['chat_use_plugin'] == selected:
return ChatScene.ChatExecution
elif chat_mode_title['knowledge_qa'] == selected:
mode= param
if mode == conversation_types["default_knownledge"]:
return ChatScene.ChatKnowledge
elif mode == conversation_types["custome"]:
return ChatScene.ChatNewKnowledge
elif mode == conversation_types["url"]:
return ChatScene.ChatUrlKnowledge
else:
return ChatScene.ChatNormal
else:
if sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
return ChatScene.ChatWithDb
sql_mode= param
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
return ChatScene.ChatWithDbExecute
else:
return ChatScene.ChatWithDbQA
return ChatScene.ChatNormal
def chatbot_callback(state, message):
print(f"chatbot_callback:{message}")
@ -237,244 +243,99 @@ def chatbot_callback(state, message):
def http_bot(
state, selected, plugin_selector, mode, sql_mode, db_selector, url_input, temperature, max_new_tokens, request: gr.Request
state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input, knowledge_name
):
logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}")
start_tstamp = time.time()
scene:ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector)
print(f"now chat scene:{scene.value}")
model_name = CFG.LLM_MODEL
if ChatScene.ChatWithDb == scene:
logger.info("chat with db mode use new architecture design")
logger.info(f"User message send!{state.conv_id},{selected}")
if chat_mode_title['knowledge_qa'] == selected:
scene: ChatScene = get_chat_mode(selected, mode)
elif chat_mode_title['chat_use_plugin'] == selected:
scene: ChatScene = get_chat_mode(selected)
else:
scene: ChatScene = get_chat_mode(selected, sql_mode)
print(f"chat scene:{scene.value}")
if ChatScene.ChatWithDbExecute == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"db_name": db_selector,
"user_input": state.last_user_input
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatWithDbQA == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"db_name": db_selector,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
chat.call()
state.messages[-1][-1] = chat.current_ai_response()
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
elif ChatScene.ChatExecution == scene:
logger.info("plugin mode use new architecture design")
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"plugin_selector": plugin_selector,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
strem_generate = chat.stream_call()
elif ChatScene.ChatNormal == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatKnowledge == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatNewKnowledge == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
"knowledge_name": knowledge_name
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatUrlKnowledge == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
"url": url_input
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
for msg in strem_generate:
state.messages[-1][-1] = msg
if not chat.prompt_template.stream_out:
logger.info("not stream out, wait model response!")
state.messages[-1][-1] = chat.nostream_call()
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
else:
logger.info("stream out start!")
try:
stream_gen = chat.stream_call()
for msg in stream_gen:
state.messages[-1][-1] = msg
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
except Exception as e:
print(traceback.format_exc())
state.messages[-1][-1] = "Error:" + str(e)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
# def generate_numbers():
# for i in range(10):
# time.sleep(0.5)
# yield f"Message:{i}"
#
# def showMessage(message):
# return message
#
# for n in generate_numbers():
# state.messages[-1][-1] = n + "▌"
# yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
else:
dbname = db_selector
# TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化
if state.skip_next:
# This generate call is skipped due to invalid inputs
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
if len(state.messages) == state.offset + 2:
query = state.messages[-2][1]
template_name = "conv_one_shot"
new_state = conv_templates[template_name].copy()
# prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文?
# 如果用户侧的问题跨度很大, 应该每一轮都加提示。
if db_selector:
new_state.append_message(
new_state.roles[0], gen_sqlgen_conversation(dbname) + query
)
new_state.append_message(new_state.roles[1], None)
else:
new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)
new_state.conv_id = uuid.uuid4().hex
state = new_state
else:
### 后续对话
query = state.messages[-2][1]
# 第一轮对话需要加入提示Prompt
if mode == conversation_types["custome"]:
template_name = "conv_one_shot"
new_state = conv_templates[template_name].copy()
# prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文?
# 如果用户侧的问题跨度很大, 应该每一轮都加提示。
if db_selector:
new_state.append_message(
new_state.roles[0], gen_sqlgen_conversation(dbname) + query
)
new_state.append_message(new_state.roles[1], None)
else:
new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)
state = new_state
prompt = state.get_prompt()
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["default_knownledge"] and not db_selector:
vector_store_config = {
"vector_store_name": "default",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["custome"] and not db_selector:
print("vector store name: ", vector_store_name["vs_name"])
vector_store_config = {
"vector_store_name": vector_store_name["vs_name"],
"text_field": "content",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["url"] and url_input:
print("url: ", url_input)
vector_store_config = {
"vector_store_name": url_input,
"text_field": "content",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = KnowledgeEmbedding(
file_path=url_input,
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
# Make requests
payload = {
"model": model_name,
"prompt": prompt,
"temperature": float(temperature),
"max_new_tokens": int(max_new_tokens),
"stop": state.sep
if state.sep_style == SeparatorStyle.SINGLE
else state.sep2,
}
logger.info(f"Requert: \n{payload}")
# 流式输出
state.messages[-1][-1] = ""
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
try:
# Stream output
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=20,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
if data["error_code"] == 0:
if "vicuna" in CFG.LLM_MODEL:
output = data["text"][skip_echo_len:].strip()
else:
output = data["text"].strip()
output = post_process_code(output)
state.messages[-1][-1] = output + ""
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
) * 5
else:
output = (
data["text"] + f" (error_code: {data['error_code']})"
)
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
# 记录运行日志
finish_tstamp = time.time()
logger.info(f"{output}")
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(finish_tstamp, 4),
"type": "chat",
"model": model_name,
"start": round(start_tstamp, 4),
"finish": round(start_tstamp, 4),
"state": state.dict(),
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
if state.messages[-1][-1].endwith(""):
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
block_css = (
code_highlight_css
@ -556,6 +417,7 @@ def build_single_model_ui():
value=dbs[0] if len(models) > 0 else "",
interactive=True,
show_label=True,
name="db_selector"
).style(container=False)
sql_mode = gr.Radio(
@ -565,6 +427,7 @@ def build_single_model_ui():
],
show_label=False,
value=get_lang_text("sql_generate_mode_none"),
name="sql_mode"
)
sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
@ -581,7 +444,8 @@ def build_single_model_ui():
value="",
interactive=True,
show_label=True,
type="value"
type="value",
name="plugin_selector"
).style(container=False)
def plugin_change(evt: gr.SelectData): # SelectData is a subclass of EventData
@ -602,13 +466,14 @@ def build_single_model_ui():
],
show_label=False,
value=llm_native_dialogue,
name="mode"
)
vs_setting = gr.Accordion(
get_lang_text("configure_knowledge_base"), open=False
get_lang_text("configure_knowledge_base"), open=False, visible=False
)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True)
url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False, name="url_input")
def show_url_input(evt:gr.SelectData):
if evt.value == url_knowledge_dialogue:
return gr.update(visible=True)
@ -619,7 +484,7 @@ def build_single_model_ui():
with vs_setting:
vs_name = gr.Textbox(
label=get_lang_text("new_klg_name"), lines=1, interactive=True
label=get_lang_text("new_klg_name"), lines=1, interactive=True, name = "vs_name"
)
vs_add = gr.Button(get_lang_text("add_as_new_klg"))
with gr.Column() as doc2vec:
@ -664,10 +529,14 @@ def build_single_model_ui():
clear_btn = gr.Button(value=get_lang_text("clear_box"), interactive=False)
gr.Markdown(learn_more_markdown)
params = [plugin_selector, mode, sql_mode, db_selector, url_input, vs_name]
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
[state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
[state, selected, temperature, max_output_tokens] + params,
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@ -676,7 +545,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
[state, selected, temperature, max_output_tokens]+ params,
[state, chatbot] + btn_list,
)
@ -684,7 +553,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
[state, selected, temperature, max_output_tokens]+ params,
[state, chatbot] + btn_list,
)
vs_add.click(

View File

@ -14,6 +14,7 @@ fonttools==4.38.0
frozenlist==1.3.3
huggingface-hub==0.13.4
importlib-resources==5.12.0
kiwisolver==1.4.4
matplotlib==3.7.0
multidict==6.0.4