Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-22 10:08:34 +00:00)

Commit 06bc4452d4 (parent 973bcce03c): Implemented a new multi-scenario dialogue architecture
@@ -277,6 +277,7 @@ class Database:
    def run(self, session, command: str, fetch: str = "all") -> List:
        """Execute a SQL command and return a string representing the results."""
+        print("sql run:" + command)
        cursor = session.execute(text(command))
        if cursor.returns_rows:
            if fetch == "all":
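For context, the session.execute(text(command)) / cursor.returns_rows pattern above is standard SQLAlchemy usage. A minimal standalone sketch of the same flow, using an in-memory SQLite engine rather than the project's MySQL connection:

from sqlalchemy import create_engine, text

engine = create_engine("sqlite:///:memory:")            # stand-in for the project's DB engine
with engine.connect() as session:
    cursor = session.execute(text("SELECT 1 AS value"))
    if cursor.returns_rows:                             # True for SELECT-style statements
        rows = cursor.fetchall()                        # the fetch == "all" branch
        print(rows)                                     # [(1,)]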
@@ -105,18 +105,14 @@ class Conversation:
    }


def gen_sqlgen_conversation(dbname):
    from pilot.connections.mysql import MySQLOperator

    mo = MySQLOperator(**(DB_SETTINGS))

    message = ""

    schemas = mo.get_schema(dbname)
    for s in schemas:
        message += s["schema_info"] + ";"
    return f"Database {dbname} Schema information as follows: {message}\n"


conv_default = Conversation(
    system=None,
    roles=("human", "ai"),
    messages=(),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_one_shot = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. "

@@ -261,7 +257,7 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回
# question:
# {question}
# """
-default_conversation = conv_one_shot
+default_conversation = conv_default


chat_mode_title = {

@@ -290,7 +286,3 @@ conv_templates = {
    "vicuna_v1": conv_vicuna_v1,
    "auto_dbgpt_one_shot": auto_dbgpt_one_shot,
}
-
-if __name__ == "__main__":
-    message = gen_sqlgen_conversation("dbgpt")
-    print(message)
@@ -21,22 +21,46 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
    }

    messages = prompt.split(stop)

    # Add history conversation
    for i in range(1, len(messages) - 2, 2):
        history.append(
            {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
        )
        history.append(
            {
                "role": "system",
                "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
            }
        )
    for message in messages:
        if len(message) <= 0:
            continue
        if "human:" in message:
            history.append(
                {"role": "user", "content": message.split("human:")[1]},
            )
        elif "system:" in message:
            history.append(
                {
                    "role": "system",
                    "content": message.split("system:")[1],
                }
            )
        elif "ai:" in message:
            history.append(
                {
                    "role": "ai",
                    "content": message.split("ai:")[1],
                }
            )
        else:
            history.append(
                {
                    "role": "system",
                    "content": message,
                }
            )

    # Move the last user message to the end of the history
    temp_his = history[::-1]
    last_user_input = None
    for m in temp_his:
        if m["role"] == "user":
            last_user_input = m
    if last_user_input:
        history.remove(last_user_input)
        history.append(last_user_input)

    # Add user query
    query = messages[-2].split(ROLE_USER + ":")[1]
    history.append({"role": "user", "content": query})
    payloads = {
        "model": "gpt-3.5-turbo",  # just for test, remove this later
        "messages": history,
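The reworked parsing above walks every segment of the flattened prompt, classifies it by its role prefix, and then moves the most recent user message to the end so the proxy model sees it last. A simplified standalone sketch of that idea (the prefixes and roles follow the code above; the sample prompt string is made up):

def build_history(prompt: str, stop: str = "###") -> list:
    history = []
    for segment in prompt.split(stop):
        if not segment:
            continue
        if "human:" in segment:
            history.append({"role": "user", "content": segment.split("human:")[1]})
        elif "ai:" in segment:
            history.append({"role": "ai", "content": segment.split("ai:")[1]})
        else:
            history.append({"role": "system", "content": segment})
    # move the last user message to the end (the intent of the reversed scan above)
    last_user = next((m for m in reversed(history) if m["role"] == "user"), None)
    if last_user:
        history.remove(last_user)
        history.append(last_user)
    return history

print(build_history("system prompt###human: hi###ai: hello###human: show tables"))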
@@ -36,7 +36,7 @@ class BaseOutputParser(ABC):
        self.sep = sep
        self.is_stream_out = is_stream_out

-    def __post_process_code(code):
+    def __post_process_code(self, code):
        sep = "\n```"
        if sep in code:
            blocks = code.split(sep)

@@ -92,7 +92,7 @@ class BaseOutputParser(ABC):
            ai_response = ai_response.replace("\n", "")
            ai_response = ai_response.replace("\_", "_")
            ai_response = ai_response.replace("\*", "*")
-            print("un_stream clear response:{}", ai_response)
+            print("un_stream ai response:", ai_response)
            return ai_response
        else:
            raise ValueError("Model server error!code=" + respObj_ex["error_code"])

@@ -140,6 +140,7 @@ class BaseOutputParser(ABC):
            cleaned_output = m.group(0)
        else:
            raise ValueError("model server out not fllow the prompt!")
+        cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '')
        return cleaned_output

    def parse_view_response(self, ai_text) -> str:
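The added cleaned_output line strips newlines and stray backslashes before the text is handed to json.loads. Roughly, on a typical model reply (sample string invented for illustration):

import json

raw = '{\n  "thoughts": "look up the user",\n  "sql": "SELECT * FROM users LIMIT 5"\n}'
cleaned = raw.strip().replace("\n", "").replace("\\n", "").replace("\\", "")
print(json.loads(cleaned))   # {'thoughts': 'look up the user', 'sql': 'SELECT * FROM users LIMIT 5'}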
@@ -31,15 +31,15 @@ DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
class PromptTemplate(BaseModel, ABC):
    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""
-    template_scene: str
+    template_scene: Optional[str]

-    template_define: str
+    template_define: Optional[str]
    """this template define"""
-    template: str
+    template: Optional[str]
    """The prompt template."""
    template_format: str = "f-string"
    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
-    response_format: str
+    response_format: Optional[str]
    """default use stream out"""
    stream_out: bool = True
    """"""

@@ -57,52 +57,12 @@ class PromptTemplate(BaseModel, ABC):
        """Return the prompt type key."""
        return "prompt"

-    def _generate_command_string(self, command: Dict[str, Any]) -> str:
-        """
-        Generate a formatted string representation of a command.
-
-        Args:
-            command (dict): A dictionary containing command information.
-
-        Returns:
-            str: The formatted command string.
-        """
-        args_string = ", ".join(
-            f'"{key}": "{value}"' for key, value in command["args"].items()
-        )
-        return f'{command["label"]}: "{command["name"]}", args: {args_string}'
-
-    def _generate_numbered_list(self, items: List[Any], item_type="list") -> str:
-        """
-        Generate a numbered list from given items based on the item_type.
-
-        Args:
-            items (list): A list of items to be numbered.
-            item_type (str, optional): The type of items in the list.
-                Defaults to 'list'.
-
-        Returns:
-            str: The formatted numbered list.
-        """
-        if item_type == "command":
-            command_strings = []
-            if self.command_registry:
-                command_strings += [
-                    str(item)
-                    for item in self.command_registry.commands.values()
-                    if item.enabled
-                ]
-            # terminate command is added manually
-            command_strings += [self._generate_command_string(item) for item in items]
-            return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
-        else:
-            return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs."""
-        kwargs["response"] = json.dumps(self.response_format, indent=4)
-        return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
+        if self.template:
+            if self.response_format:
+                kwargs["response"] = json.dumps(self.response_format, indent=4)
+            return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)

    def add_goals(self, goal: str) -> None:
        self.goals.append(goal)
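The reworked format only substitutes variables when a template is actually set, and only injects a "response" variable when a response format is defined, which is what lets scenes such as chat_normal register a template-less prompt. A rough standalone equivalent of that guard (plain str.format standing in for the project's f-string formatter):

import json

def format_prompt(template, response_format, **kwargs):
    if not template:                     # template-less scenes skip prompt construction
        return None
    if response_format:                  # only JSON-style scenes get a "response" variable
        kwargs["response"] = json.dumps(response_format, indent=4)
    return template.format(**kwargs)

print(format_prompt("Question: {input}\nAnswer as {response}",
                    {"sql": "SQL Query to run"}, input="count the users"))
print(format_prompt(None, None, input="hello"))   # None: nothing to format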
@@ -2,8 +2,10 @@ from enum import Enum


class ChatScene(Enum):
-    ChatWithDb = "chat_with_db"
+    ChatWithDbExecute = "chat_with_db_execute"
+    ChatWithDbQA = "chat_with_db_qa"
    ChatExecution = "chat_execution"
    ChatKnowledge = "chat_default_knowledge"
    ChatNewKnowledge = "chat_new_knowledge"
+    ChatUrlKnowledge = "chat_url_knowledge"
    ChatNormal = "chat_normal"
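Each of these scene values doubles as the key under which the matching prompt.py module registers its PromptTemplate (via CFG.prompt_templates.update(...) later in this commit). The lookup pattern, reduced to a self-contained sketch with a plain dict standing in for the project's Config:

from enum import Enum

class ChatScene(Enum):
    ChatWithDbExecute = "chat_with_db_execute"
    ChatWithDbQA = "chat_with_db_qa"
    ChatNormal = "chat_normal"

prompt_templates = {}                                    # stands in for CFG.prompt_templates
prompt_templates[ChatScene.ChatWithDbExecute.value] = "<sql-generation template>"

scene = ChatScene.ChatWithDbExecute
print(prompt_templates[scene.value])                     # template selected by scene value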
@@ -56,7 +56,7 @@ class BaseChat(ABC):

        arbitrary_types_allowed = True

-    def __init__(self, chat_mode, chat_session_id, current_user_input):
+    def __init__(self, temperature, max_new_tokens, chat_mode, chat_session_id, current_user_input):
        self.chat_session_id = chat_session_id
        self.chat_mode = chat_mode
        self.current_user_input: str = current_user_input

@@ -64,12 +64,12 @@ class BaseChat(ABC):
        ### TODO
        self.memory = FileHistoryMemory(chat_session_id)
        ### load prompt template
-        self.prompt_template: PromptTemplate = CFG.prompt_templates[
-            self.chat_mode.value
-        ]
+        self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
        self.history_message: List[OnceConversation] = []
        self.current_message: OnceConversation = OnceConversation()
        self.current_tokens_used: int = 0
+        self.temperature = temperature
+        self.max_new_tokens = max_new_tokens
        ### load chat_session_id's chat history
        self._load_history(self.chat_session_id)
@@ -92,15 +92,17 @@ class BaseChat(ABC):
        pass

    def __call_base(self):
        input_values = self.generate_input_values()
        ### Chat sequence advance
        self.current_message.chat_order = len(self.history_message) + 1
        self.current_message.add_user_message(self.current_user_input)
        self.current_message.start_date = datetime.datetime.now()
        # TODO
        self.current_message.tokens = 0
+        current_prompt = None

-        current_prompt = self.prompt_template.format(**input_values)
+        if self.prompt_template.template:
+            current_prompt = self.prompt_template.format(**input_values)

        ### Build the current conversation. Should the prompt be built as for a first-round conversation? Should switching databases be considered?
        if self.history_message:

@@ -108,8 +110,8 @@ class BaseChat(ABC):
            logger.info(
                f"There are already {len(self.history_message)} rounds of conversations!"
            )

-        self.current_message.add_system_message(current_prompt)
+        if current_prompt:
+            self.current_message.add_system_message(current_prompt)

        payload = {
            "model": self.llm_model,

@@ -118,7 +120,6 @@ class BaseChat(ABC):
            "max_new_tokens": int(self.max_new_tokens),
            "stop": self.prompt_template.sep,
        }
        logger.info(f"Request: \n{payload}")
        return payload

    def stream_call(self):
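The payload assembled by __call_base is what gets POSTed to the model server's generate/generate_stream endpoints; the new temperature and max_new_tokens constructor arguments flow straight into it. A sketch of the request shape (field values are illustrative; the endpoint path follows the code below):

import json

payload = {
    "model": "vicuna-13b",              # illustrative model name
    "prompt": "human: show me all users###",
    "temperature": 0.7,                 # now passed in per request
    "max_new_tokens": 512,              # now passed in per request
    "stop": "###",                      # prompt_template.sep
}
print(json.dumps(payload, indent=2))
# e.g. requests.post(urljoin(MODEL_SERVER, "generate_stream"), json=payload, timeout=120)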
@@ -127,30 +128,18 @@ class BaseChat(ABC):
        ai_response_text = ""
        try:
            show_info = ""
+            response = requests.post(
+                urljoin(CFG.MODEL_SERVER, "generate_stream"),
+                headers=headers,
+                json=payload,
+                timeout=120,
+            )

-            # response = requests.post(
-            #     urljoin(CFG.MODEL_SERVER, "generate_stream"),
-            #     headers=headers,
-            #     json=payload,
-            #     timeout=120,
-            # )
-            #
-            # ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
+            ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)

-            # for resp_text_trunck in ai_response_text:
-            #     show_info = resp_text_trunck
-            #     yield resp_text_trunck + "▌"
-            #
-
-            #### MOCK TEST
-            def mock_stream_out():
-                for i in range(1, 11):
-                    time.sleep(0.5)
-                    yield f"Message:{i}"
-
-            for msg in mock_stream_out():
-                show_info = msg
-                yield msg + "▌"
+            for resp_text_trunck in ai_response_text:
+                show_info = resp_text_trunck
+                yield resp_text_trunck + "▌"

            self.current_message.add_ai_message(show_info)
@@ -186,13 +175,13 @@ class BaseChat(ABC):
            result = self.do_with_prompt_response(prompt_define_response)

            if hasattr(prompt_define_response, "thoughts"):
-                if prompt_define_response.thoughts.get("speak"):
+                if hasattr(prompt_define_response.thoughts, "speak"):
                    self.current_message.add_view_message(
                        self.prompt_template.output_parser.parse_view_response(
                            prompt_define_response.thoughts.get("speak"), result
                        )
                    )
-                elif prompt_define_response.thoughts.get("reasoning"):
+                elif hasattr(prompt_define_response.thoughts, "reasoning"):
                    self.current_message.add_view_message(
                        self.prompt_template.output_parser.parse_view_response(
                            prompt_define_response.thoughts.get("reasoning"), result
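Note the behavioural shift in this hunk: thoughts.get("speak") asks a dict whether the key holds a truthy value, while hasattr(thoughts, "speak") asks whether the object exposes an attribute, so the two checks only agree when thoughts is an attribute-style object. A tiny illustration (the class is invented for the example):

class Thoughts:
    def __init__(self, speak=None):
        self.speak = speak

as_dict = {"speak": "I will run a query"}
as_obj = Thoughts(speak="I will run a query")

print(as_dict.get("speak"))          # 'I will run a query'  (dict lookup)
print(hasattr(as_dict, "speak"))     # False - dicts have no 'speak' attribute
print(hasattr(as_obj, "speak"))      # True  - attribute exists on the object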
@@ -223,15 +212,18 @@ class BaseChat(ABC):

    def call(self):
        if self.prompt_template.stream_out:
            yield self.stream_call()
        else:
            return self.nostream_call()

    def generate_llm_text(self) -> str:
-        text = self.prompt_template.template_define + self.prompt_template.sep
-        ### process history first
+        text = ""
+        if self.prompt_template.template_define:
+            text = self.prompt_template.template_define + self.prompt_template.sep
+
+        ### process conversation history
        if len(self.history_message) > self.chat_retention_rounds:
-            ### Merge the first round and the last round of history into the conversation record; view messages shown to the user must be filtered out when building the context prompt
+            ### Merge the first round and the last n rounds of history into the conversation record; view messages shown to the user must be filtered out when building the context prompt
            for first_message in self.history_message[0].messages:
                if not isinstance(first_message, ViewMessage):
                    text += (

@@ -262,8 +254,8 @@ class BaseChat(ABC):
                        + message.content
                        + self.prompt_template.sep
                    )

            ### current conversation
            for now_message in self.current_message.messages:
                text += (
                    now_message.type + ":" + now_message.content + self.prompt_template.sep
@@ -298,34 +290,3 @@ class BaseChat(ABC):
        """
        pass

-
-if __name__ == "__main__":
-    #
-    # def call_back(t, m):
-    #     print(t)
-    #     print(m)
-    #
-    # def my_fn(call_fn, xx):
-    #     call_fn(1, xx)
-    #
-    #
-    # my_fn(call_back, "1231")
-
-    def my_generator():
-        while True:
-            value = yield
-            print('Received value:', value)
-            if value == 'stop':
-                return
-
-    # create the generator object
-    gen = my_generator()
-
-    # start the generator
-    next(gen)
-
-    # send data to the generator
-    gen.send('Hello')
-    gen.send('World')
-    gen.send('stop')
pilot/scene/chat_db/auto_execute/__init__.py (new file, 0 lines)
pilot/scene/chat_db/auto_execute/chat.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import json

from pilot.scene.base_message import (
    HumanMessage,
    ViewMessage,
)
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
    generate_htm_table,
)
from pilot.scene.chat_db.auto_execute.prompt import prompt

CFG = Config()


class ChatWithDbAutoExecute(BaseChat):
    chat_scene: str = ChatScene.ChatWithDbExecute.value

    """Number of results to return from the query"""

    def __init__(self, temperature, max_new_tokens, chat_session_id, db_name, user_input):
        """ """
        super().__init__(temperature=temperature,
                         max_new_tokens=max_new_tokens,
                         chat_mode=ChatScene.ChatWithDbExecute,
                         chat_session_id=chat_session_id,
                         current_user_input=user_input)
        if not db_name:
            raise ValueError(f"{ChatScene.ChatWithDbExecute.value} mode should choose a db!")
        self.db_name = db_name
        self.database = CFG.local_db
        # Prepare the DB information (get a connection to the selected database)
        self.db_connect = self.database.get_session(self.db_name)
        self.top_k: int = 5

    def generate_input_values(self):
        input_values = {
            "input": self.current_user_input,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": self.database.table_simple_info(self.db_connect)
        }
        return input_values

    def do_with_prompt_response(self, prompt_response):
        return self.database.run(self.db_connect, prompt_response.sql)


if __name__ == "__main__":
    ss = "{\n \"thoughts\": \"to get the user's city, we need to join the users table with the tran_order table using the user_name column. we also need to filter the results to only show orders for user test1.\",\n \"sql\": \"select o.order_id, o.product_name, u.city from tran_order o join users u on o.user_name = u.user_name where o.user_name = 'test1' limit 5\"\n}"
    ss.strip().replace('\n', '').replace('\\n', '').replace('', '').replace(' ', '').replace('\\', '').replace('\\', '')
    print(ss)
    json.loads(ss)
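End to end, the auto-execute scene expects the model to answer with the simple JSON shape defined in its prompt module (thoughts plus sql), which the output parser turns into a SqlAction whose SQL is then run against the selected database. A hedged sketch of that round trip with a hand-written model reply and a stub runner (no real DB involved):

import json
from typing import NamedTuple

class SqlAction(NamedTuple):            # mirrors the parser's result shape
    sql: str
    thoughts: str

model_reply = '{"thoughts": "count rows in users", "sql": "SELECT COUNT(*) FROM users"}'
parsed = json.loads(model_reply)
action = SqlAction(parsed["sql"], parsed["thoughts"])

def run(sql: str):                      # stand-in for Database.run(session, sql)
    return [("rows", 42)]

print(action.thoughts, "->", run(action.sql))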
@@ -22,7 +22,9 @@ class DbChatOutputParser(BaseOutputParser):

    def parse_prompt_response(self, model_out_text):
-        response = json.loads(super().parse_prompt_response(model_out_text))
+        clean_str = super().parse_prompt_response(model_out_text)
+        print("clean prompt response:", clean_str)
+        response = json.loads(clean_str)
        sql, thoughts = response["sql"], response["thoughts"]
        return SqlAction(sql, thoughts)
@@ -1,36 +1,30 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
-from pilot.scene.chat_db.out_parser import DbChatOutputParser, SqlAction
+from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle

CFG = Config()

PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""

PROMPT_SUFFIX = """Only use the following tables:
{table_info}

Question: {input}

"""

_DEFAULT_TEMPLATE = """
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
If the given table is beyond the scope of use, do not use it forcibly.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

"""

-_mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
-Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
-Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
-Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
-Pay attention to use CURDATE() function to get the current date, if the question involves "today".
-PROMPT_SUFFIX = """Only use the following tables:
-{table_info}
-
-Question: {input}
-
-"""

@@ -49,17 +43,16 @@ RESPONSE_FORMAT = {
}

RESPONSE_FORMAT_SIMPLE = {
    "thoughts": "thoughts summary to say to user",
    "sql": "SQL Query to run",
}


PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_NEED_STREAM_OUT = False

-chat_db_prompt = PromptTemplate(
-    template_scene=ChatScene.ChatWithDb.value,
+prompt = PromptTemplate(
+    template_scene=ChatScene.ChatWithDbExecute.value,
    input_variables=["input", "table_info", "dialect", "top_k", "response"],
    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4),
    template_define=PROMPT_SCENE_DEFINE,

@@ -69,5 +62,5 @@ chat_db_prompt = PromptTemplate(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
)
+CFG.prompt_templates.update({prompt.template_scene: prompt})

-CFG.prompt_templates.update({chat_db_prompt.template_scene: chat_db_prompt})
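Putting the pieces of this prompt module together, the text sent to the model for one question is roughly the scene definition, the dialect/top_k instructions, the table list, and the JSON response contract. A rough rendering with a paraphrased template and made-up table info (whitespace simplified; not the exact string produced by PromptTemplate):

import json

RESPONSE_FORMAT_SIMPLE = {"thoughts": "thoughts summary to say to user", "sql": "SQL Query to run"}

template = (
    "You are a SQL expert. Given an input question, first create a syntactically "
    "correct {dialect} query to run, limited to at most {top_k} results.\n"
    "Only use the following tables:\n{table_info}\n"
    "Question: {input}\n"
    "Respond in JSON: {response}"
)
print(template.format(
    dialect="mysql",
    top_k=5,
    table_info="users(user_name, city); tran_order(order_id, user_name)",  # invented schema
    input="which city is user test1 in?",
    response=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4),
))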
@@ -1,240 +0,0 @@
|
||||
import requests
|
||||
import datetime
|
||||
import threading
|
||||
import json
|
||||
import traceback
|
||||
from urllib.parse import urljoin
|
||||
from sqlalchemy import (
|
||||
MetaData,
|
||||
Table,
|
||||
create_engine,
|
||||
inspect,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from pilot.scene.base_message import (
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
HumanMessage,
|
||||
AIMessage,
|
||||
ViewMessage,
|
||||
)
|
||||
from pilot.scene.base_chat import BaseChat, logger, headers
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.chat_db.out_parser import SqlAction
|
||||
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
|
||||
from pilot.utils import (
|
||||
build_logger,
|
||||
server_error_msg,
|
||||
)
|
||||
from pilot.common.markdown_text import (
|
||||
generate_markdown_table,
|
||||
generate_htm_table,
|
||||
datas_to_table_html,
|
||||
)
|
||||
from pilot.scene.chat_db.prompt import chat_db_prompt
|
||||
from pilot.out_parser.base import BaseOutputParser
|
||||
from pilot.scene.chat_db.out_parser import DbChatOutputParser
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ChatWithDb(BaseChat):
|
||||
chat_scene: str = ChatScene.ChatWithDb.value
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, db_name, user_input):
|
||||
""" """
|
||||
super().__init__(chat_mode=ChatScene.ChatWithDb, chat_session_id=chat_session_id, current_user_input=user_input)
|
||||
if not db_name:
|
||||
raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!")
|
||||
self.db_name = db_name
|
||||
self.database = CFG.local_db
|
||||
# 准备DB信息(拿到指定库的链接)
|
||||
self.db_connect = self.database.get_session(self.db_name)
|
||||
self.top_k: int = 5
|
||||
|
||||
def generate_input_values(self):
|
||||
input_values = {
|
||||
"input": self.current_user_input,
|
||||
"top_k": str(self.top_k),
|
||||
"dialect": self.database.dialect,
|
||||
"table_info": self.database.table_simple_info(self.db_connect)
|
||||
}
|
||||
return input_values
|
||||
|
||||
def do_with_prompt_response(self, prompt_response):
|
||||
return self.database.run(self.db_connect, prompt_response.sql)
|
||||
|
||||
# def call(self) -> str:
|
||||
# input_values = {
|
||||
# "input": self.current_user_input,
|
||||
# "top_k": str(self.top_k),
|
||||
# "dialect": self.database.dialect,
|
||||
# "table_info": self.database.table_simple_info(self.db_connect),
|
||||
# # "stop": self.sep_style,
|
||||
# }
|
||||
#
|
||||
# ### Chat sequence advance
|
||||
# self.current_message.chat_order = len(self.history_message) + 1
|
||||
# self.current_message.add_user_message(self.current_user_input)
|
||||
# self.current_message.start_date = datetime.datetime.now()
|
||||
# # TODO
|
||||
# self.current_message.tokens = 0
|
||||
#
|
||||
# current_prompt = self.prompt_template.format(**input_values)
|
||||
#
|
||||
# ### 构建当前对话, 是否安第一次对话prompt构造? 是否考虑切换库
|
||||
# if self.history_message:
|
||||
# ## TODO 带历史对话记录的场景需要确定切换库后怎么处理
|
||||
# logger.info(
|
||||
# f"There are already {len(self.history_message)} rounds of conversations!"
|
||||
# )
|
||||
#
|
||||
# self.current_message.add_system_message(current_prompt)
|
||||
#
|
||||
# payload = {
|
||||
# "model": self.llm_model,
|
||||
# "prompt": self.generate_llm_text(),
|
||||
# "temperature": float(self.temperature),
|
||||
# "max_new_tokens": int(self.max_new_tokens),
|
||||
# "stop": self.prompt_template.sep,
|
||||
# }
|
||||
# logger.info(f"Requert: \n{payload}")
|
||||
# ai_response_text = ""
|
||||
# try:
|
||||
# ### 走非流式的模型服务接口
|
||||
#
|
||||
# response = requests.post(
|
||||
# urljoin(CFG.MODEL_SERVER, "generate"),
|
||||
# headers=headers,
|
||||
# json=payload,
|
||||
# timeout=120,
|
||||
# )
|
||||
# ai_response_text = (
|
||||
# self.prompt_template.output_parser.parse_model_server_out(response)
|
||||
# )
|
||||
# self.current_message.add_ai_message(ai_response_text)
|
||||
# prompt_define_response = (
|
||||
# self.prompt_template.output_parser.parse_prompt_response(
|
||||
# ai_response_text
|
||||
# )
|
||||
# )
|
||||
#
|
||||
# result = self.database.run(self.db_connect, prompt_define_response.sql)
|
||||
#
|
||||
# if hasattr(prompt_define_response, "thoughts"):
|
||||
# if prompt_define_response.thoughts.get("speak"):
|
||||
# self.current_message.add_view_message(
|
||||
# self.prompt_template.output_parser.parse_view_response(
|
||||
# prompt_define_response.thoughts.get("speak"), result
|
||||
# )
|
||||
# )
|
||||
# elif prompt_define_response.thoughts.get("reasoning"):
|
||||
# self.current_message.add_view_message(
|
||||
# self.prompt_template.output_parser.parse_view_response(
|
||||
# prompt_define_response.thoughts.get("reasoning"), result
|
||||
# )
|
||||
# )
|
||||
# else:
|
||||
# self.current_message.add_view_message(
|
||||
# self.prompt_template.output_parser.parse_view_response(
|
||||
# prompt_define_response.thoughts, result
|
||||
# )
|
||||
# )
|
||||
# else:
|
||||
# self.current_message.add_view_message(
|
||||
# self.prompt_template.output_parser.parse_view_response(
|
||||
# prompt_define_response, result
|
||||
# )
|
||||
# )
|
||||
#
|
||||
# except Exception as e:
|
||||
# print(traceback.format_exc())
|
||||
# logger.error("model response parase faild!" + str(e))
|
||||
# self.current_message.add_view_message(
|
||||
# f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
|
||||
# )
|
||||
# ### 对话记录存储
|
||||
# self.memory.append(self.current_message)
|
||||
|
||||
def chat_show(self):
|
||||
ret = []
|
||||
# 单论对话只能有一次User 记录 和一次 AI 记录
|
||||
# TODO 推理过程前端展示。。。
|
||||
for message in self.current_message.messages:
|
||||
if isinstance(message, HumanMessage):
|
||||
ret[-1][-2] = message.content
|
||||
# 是否展示推理过程
|
||||
if isinstance(message, ViewMessage):
|
||||
ret[-1][-1] = message.content
|
||||
|
||||
return ret
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatExecution.value
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# chat: ChatWithDb = ChatWithDb("chat123", "gpt-user", "查询用户信息")
|
||||
#
|
||||
# chat.call()
|
||||
#
|
||||
# resp = chat.chat_show()
|
||||
#
|
||||
# print(vars(resp))
|
||||
|
||||
# memory = FileHistoryMemory("test123")
|
||||
# once1 = OnceConversation()
|
||||
# once1.add_user_message("问题测试")
|
||||
# once1.add_system_message("prompt1")
|
||||
# once1.add_system_message("prompt2")
|
||||
# once1.chat_order = 1
|
||||
# once1.set_start_time(datetime.datetime.now())
|
||||
# memory.append(once1)
|
||||
#
|
||||
# once = OnceConversation()
|
||||
# once.add_user_message("问题测试2")
|
||||
# once.add_system_message("prompt3")
|
||||
# once.add_system_message("prompt4")
|
||||
# once.chat_order = 2
|
||||
# once.set_start_time(datetime.datetime.now())
|
||||
# memory.append(once)
|
||||
|
||||
db: Database = CFG.local_db
|
||||
db_connect = db.get_session("gpt-user")
|
||||
data = db.run(db_connect, "select * from users")
|
||||
print(generate_htm_table(data))
|
||||
|
||||
#
|
||||
# print(db.run(db_connect, "select * from users"))
|
||||
#
|
||||
# #
|
||||
# # def print_numbers():
|
||||
# # db_connect1 = db.get_session("dbgpt-test")
|
||||
# # cursor1 = db_connect1.execute(text("select * from test_name"))
|
||||
# # if cursor1.returns_rows:
|
||||
# # result1 = cursor1.fetchall()
|
||||
# # print( result1)
|
||||
# #
|
||||
# #
|
||||
# # # 创建线程
|
||||
# # t = threading.Thread(target=print_numbers)
|
||||
# # # 启动线程
|
||||
# # t.start()
|
||||
#
|
||||
# print(db.run(db_connect, "select * from tran_order"))
|
||||
#
|
||||
# print(db.run(db_connect, "select count(*) as aa from tran_order"))
|
||||
#
|
||||
# print(db.table_simple_info(db_connect))
|
||||
# my_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
# index = 3
|
||||
# last_three_elements = my_list[-index:]
|
||||
# print(last_three_elements)
|
pilot/scene/chat_db/professional_qa/__init__.py (new file, 0 lines)
pilot/scene/chat_db/professional_qa/chat.py (new file, 56 lines)
@@ -0,0 +1,56 @@
|
||||
from pilot.scene.base_message import (
|
||||
HumanMessage,
|
||||
ViewMessage,
|
||||
)
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
from pilot.common.markdown_text import (
|
||||
generate_htm_table,
|
||||
)
|
||||
from pilot.scene.chat_db.professional_qa.prompt import prompt
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ChatWithDbQA(BaseChat):
|
||||
chat_scene: str = ChatScene.ChatWithDbQA.value
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input):
|
||||
""" """
|
||||
super().__init__(temperature=temperature,
|
||||
max_new_tokens=max_new_tokens,
|
||||
chat_mode=ChatScene.ChatWithDbQA,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input)
|
||||
self.db_name = db_name
|
||||
if db_name:
|
||||
self.database = CFG.local_db
|
||||
# Prepare the DB information (get a connection to the selected database)
|
||||
self.db_connect = self.database.get_session(self.db_name)
|
||||
self.top_k: int = 5
|
||||
|
||||
def generate_input_values(self):
|
||||
|
||||
table_info = ""
|
||||
dialect = "mysql"
|
||||
if self.db_name:
|
||||
table_info = self.database.table_simple_info(self.db_connect)
|
||||
dialect = self.database.dialect
|
||||
|
||||
input_values = {
|
||||
"input": self.current_user_input,
|
||||
"top_k": str(self.top_k),
|
||||
"dialect": dialect,
|
||||
"table_info": table_info
|
||||
}
|
||||
return input_values
|
||||
|
||||
def do_with_prompt_response(self, prompt_response):
|
||||
if self.auto_execute:
|
||||
return self.database.run(self.db_connect, prompt_response.sql)
|
||||
else:
|
||||
return prompt_response
|
pilot/scene/chat_db/professional_qa/out_parser.py (new file, 22 lines)
@@ -0,0 +1,22 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, NamedTuple
|
||||
import pandas as pd
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
|
||||
class NormalChatOutputParser(BaseOutputParser):
|
||||
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
return model_out_text
|
||||
|
||||
def parse_view_response(self, ai_text) -> str:
|
||||
return super().parse_view_response(ai_text)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
pass
|
pilot/scene/chat_db/professional_qa/prompt.py (new file, 48 lines)
@@ -0,0 +1,48 @@
|
||||
import json
|
||||
import importlib
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.chat_db.professional_qa.out_parser import NormalChatOutputParser
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
|
||||
CFG = Config()
|
||||
|
||||
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. """
|
||||
|
||||
PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info:
|
||||
{table_info}
|
||||
|
||||
Question: {input}
|
||||
|
||||
"""
|
||||
|
||||
_DEFAULT_TEMPLATE = """
|
||||
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
|
||||
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
|
||||
You can order the results by a relevant column to return the most interesting examples in the database.
|
||||
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
|
||||
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
|
||||
PROMPT_NEED_NEED_STREAM_OUT = True
|
||||
|
||||
prompt = PromptTemplate(
|
||||
template_scene=ChatScene.ChatWithDbQA.value,
|
||||
input_variables=["input", "table_info", "dialect", "top_k"],
|
||||
response_format=None,
|
||||
template_define=PROMPT_SCENE_DEFINE,
|
||||
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX ,
|
||||
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||
output_parser=NormalChatOutputParser(
|
||||
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
|
||||
),
|
||||
)
|
||||
|
||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
||||
|
@ -10,8 +10,7 @@ from pilot.scene.base import ChatScene
|
||||
from pilot.configs.config import Config
|
||||
from pilot.commands.command import execute_command
|
||||
from pilot.prompts.generator import PluginPromptGenerator
|
||||
|
||||
from pilot.scene.chat_execution.prompt import chat_plugin_prompt
|
||||
from pilot.scene.chat_execution.prompt import prompt
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -20,8 +19,12 @@ class ChatWithPlugin(BaseChat):
|
||||
plugins_prompt_generator:PluginPromptGenerator
|
||||
select_plugin: str = None
|
||||
|
||||
def __init__(self, chat_session_id, user_input, plugin_selector:str=None):
|
||||
super().__init__(chat_mode=ChatScene.ChatExecution, chat_session_id=chat_session_id, current_user_input=user_input)
|
||||
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, plugin_selector:str=None):
|
||||
super().__init__(temperature=temperature,
|
||||
max_new_tokens=max_new_tokens,
|
||||
chat_mode=ChatScene.ChatExecution,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input)
|
||||
self.plugins_prompt_generator = PluginPromptGenerator()
|
||||
self.plugins_prompt_generator.command_registry = CFG.command_registry
|
||||
# 加载插件中可用命令
|
||||
|
@ -20,8 +20,8 @@ class PluginChatOutputParser(BaseOutputParser):
|
||||
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
response = json.loads(super().parse_prompt_response(model_out_text))
|
||||
sql, thoughts = response["command"], response["thoughts"]
|
||||
return PluginAction(sql, thoughts)
|
||||
command, thoughts = response["command"], response["thoughts"]
|
||||
return PluginAction(command, thoughts)
|
||||
|
||||
def parse_view_response(self, ai_text) -> str:
|
||||
return super().parse_view_response(ai_text)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import importlib
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
@ -50,7 +51,7 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
### Whether the model service is streaming output
|
||||
PROMPT_NEED_NEED_STREAM_OUT = False
|
||||
|
||||
chat_plugin_prompt = PromptTemplate(
|
||||
prompt = PromptTemplate(
|
||||
template_scene=ChatScene.ChatExecution.value,
|
||||
input_variables=["input", "constraints", "commands_infos", "response"],
|
||||
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
|
||||
@ -62,4 +63,4 @@ chat_plugin_prompt = PromptTemplate(
|
||||
),
|
||||
)
|
||||
|
||||
CFG.prompt_templates.update({chat_plugin_prompt.template_scene: chat_plugin_prompt})
|
||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
@@ -1,65 +0,0 @@
|
||||
import json
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
|
||||
from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
|
||||
|
||||
|
||||
CFG = Config()
|
||||
|
||||
PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.Play to your strengths as an LLM and pursue simple strategies with no legal complications."""
|
||||
|
||||
PROMPT_SUFFIX = """
|
||||
Goals:
|
||||
{input}
|
||||
|
||||
"""
|
||||
|
||||
_DEFAULT_TEMPLATE = """
|
||||
Constraints:
|
||||
Exclusively use the commands listed in double quotes e.g. "command name"
|
||||
Reflect on past decisions and strategies to refine your approach.
|
||||
Constructively self-criticize your big-picture behavior constantly.
|
||||
{constraints}
|
||||
|
||||
Commands:
|
||||
{commands_infos}
|
||||
"""
|
||||
|
||||
|
||||
PROMPT_RESPONSE = """You must respond in JSON format as following format:
|
||||
{response}
|
||||
|
||||
Ensure the response is correct json and can be parsed by Python json.loads
|
||||
"""
|
||||
|
||||
RESPONSE_FORMAT = {
|
||||
"thoughts": {
|
||||
"text": "thought",
|
||||
"reasoning": "reasoning",
|
||||
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
|
||||
"criticism": "constructive self-criticism",
|
||||
"speak": "thoughts summary to say to user",
|
||||
},
|
||||
"command": {"name": "command name", "args": {"arg name": "value"}},
|
||||
}
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
### Whether the model service is streaming output
|
||||
PROMPT_NEED_NEED_STREAM_OUT = False
|
||||
|
||||
chat_plugin_prompt = PromptTemplate(
|
||||
template_scene=ChatScene.ChatExecution.value,
|
||||
input_variables=["input", "table_info", "dialect", "top_k", "response"],
|
||||
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
|
||||
template_define=PROMPT_SCENE_DEFINE,
|
||||
template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,
|
||||
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||
output_parser=PluginChatOutputParser(
|
||||
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
|
||||
),
|
||||
)
|
||||
|
||||
CFG.prompt_templates.update({chat_plugin_prompt.template_scene: chat_plugin_prompt})
|
@@ -1,8 +1,14 @@
from pilot.scene.base_chat import BaseChat
from pilot.singleton import Singleton
-from pilot.scene.chat_db.chat import ChatWithDb
+import inspect
+import importlib
from pilot.scene.chat_execution.chat import ChatWithPlugin

+from pilot.scene.chat_normal.chat import ChatNormal
+from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
+from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
+from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
+from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
+from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge

class ChatFactory(metaclass=Singleton):
    @staticmethod

@@ -13,5 +19,5 @@ class ChatFactory(metaclass=Singleton):
            if cls.chat_scene == chat_mode:
                implementation = cls(**kwargs)
        if implementation == None:
-            raise Exception("Invalid implementation name:" + chat_mode)
+            raise Exception(f"Invalid implementation name:{chat_mode}")
        return implementation
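The factory imports every chat implementation (so each subclass and its prompt get registered on import) and then scans the known BaseChat subclasses for the one whose chat_scene matches the requested mode. Reduced to a standalone sketch of that dispatch pattern (the __subclasses__ scan here is illustrative):

class BaseChat:
    chat_scene: str = ""

class ChatNormal(BaseChat):
    chat_scene = "chat_normal"

class ChatWithDbQA(BaseChat):
    chat_scene = "chat_with_db_qa"

def get_implementation(chat_mode: str, **kwargs) -> BaseChat:
    for cls in BaseChat.__subclasses__():
        if cls.chat_scene == chat_mode:
            return cls(**kwargs)
    raise Exception(f"Invalid implementation name:{chat_mode}")

print(type(get_implementation("chat_with_db_qa")).__name__)   # ChatWithDbQA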
pilot/scene/chat_knowledge/custom/chat.py (new file, 69 lines)
@@ -0,0 +1,69 @@
|
||||
|
||||
from pilot.scene.base_chat import BaseChat, logger, headers
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.common.markdown_text import (
|
||||
generate_markdown_table,
|
||||
generate_htm_table,
|
||||
datas_to_table_html,
|
||||
)
|
||||
|
||||
from pilot.configs.model_config import (
|
||||
DATASETS_DIR,
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
LLM_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
)
|
||||
|
||||
from pilot.scene.chat_normal.prompt import prompt
|
||||
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ChatNewKnowledge (BaseChat):
|
||||
chat_scene: str = ChatScene.ChatNewKnowledge.value
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, knowledge_name):
|
||||
""" """
|
||||
super().__init__(temperature=temperature,
|
||||
max_new_tokens=max_new_tokens,
|
||||
chat_mode=ChatScene.ChatNewKnowledge,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input)
|
||||
self.knowledge_name = knowledge_name
|
||||
vector_store_config = {
|
||||
"vector_store_name": knowledge_name,
|
||||
"text_field": "content",
|
||||
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
}
|
||||
self.knowledge_embedding_client = KnowledgeEmbedding(
|
||||
file_path="",
|
||||
model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||
local_persist=False,
|
||||
vector_store_config=vector_store_config,
|
||||
)
|
||||
|
||||
|
||||
def generate_input_values(self):
|
||||
docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
|
||||
docs = docs[:2000]
|
||||
input_values = {
|
||||
"context": docs,
|
||||
"question": self.current_user_input
|
||||
}
|
||||
return input_values
|
||||
|
||||
def do_with_prompt_response(self, prompt_response):
|
||||
return prompt_response
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatNewKnowledge.value
|
pilot/scene/chat_knowledge/custom/out_parser.py (new file, 22 lines)
@@ -0,0 +1,22 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, NamedTuple
|
||||
import pandas as pd
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
|
||||
class NormalChatOutputParser(BaseOutputParser):
|
||||
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
return model_out_text
|
||||
|
||||
def parse_view_response(self, ai_text) -> str:
|
||||
return super().parse_view_response(ai_text)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
pass
|
pilot/scene/chat_knowledge/custom/prompt.py (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
import builtins
|
||||
import importlib
|
||||
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
|
||||
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
|
||||
|
||||
|
||||
CFG = Config()
|
||||
|
||||
_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
|
||||
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
|
||||
已知内容:
|
||||
{context}
|
||||
问题:
|
||||
{question}
|
||||
"""
|
||||
|
||||
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
|
||||
PROMPT_NEED_NEED_STREAM_OUT = True
|
||||
|
||||
prompt = PromptTemplate(
|
||||
template_scene=ChatScene.ChatNewKnowledge.value,
|
||||
input_variables=["context", "question"],
|
||||
response_format=None,
|
||||
template_define=None,
|
||||
template=_DEFAULT_TEMPLATE,
|
||||
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||
output_parser=NormalChatOutputParser(
|
||||
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
||||
|
||||
|
pilot/scene/chat_knowledge/default/chat.py (new file, 66 lines)
@@ -0,0 +1,66 @@
|
||||
|
||||
from pilot.scene.base_chat import BaseChat, logger, headers
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.common.markdown_text import (
|
||||
generate_markdown_table,
|
||||
generate_htm_table,
|
||||
datas_to_table_html,
|
||||
)
|
||||
|
||||
from pilot.configs.model_config import (
|
||||
DATASETS_DIR,
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
LLM_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
)
|
||||
|
||||
from pilot.scene.chat_normal.prompt import prompt
|
||||
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ChatDefaultKnowledge (BaseChat):
|
||||
chat_scene: str = ChatScene.ChatKnowledge.value
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input):
|
||||
""" """
|
||||
super().__init__(temperature=temperature,
|
||||
max_new_tokens=max_new_tokens,
|
||||
chat_mode=ChatScene.ChatKnowledge,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input)
|
||||
vector_store_config = {
|
||||
"vector_store_name": "default",
|
||||
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
}
|
||||
self.knowledge_embedding_client = KnowledgeEmbedding(
|
||||
file_path="",
|
||||
model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||
local_persist=False,
|
||||
vector_store_config=vector_store_config,
|
||||
)
|
||||
|
||||
def generate_input_values(self):
|
||||
docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
|
||||
docs = docs[:2000]
|
||||
input_values = {
|
||||
"context": docs,
|
||||
"question": self.current_user_input
|
||||
}
|
||||
return input_values
|
||||
|
||||
def do_with_prompt_response(self, prompt_response):
|
||||
return prompt_response
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatKnowledge.value
|
pilot/scene/chat_knowledge/default/out_parser.py (new file, 22 lines)
@@ -0,0 +1,22 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, NamedTuple
|
||||
import pandas as pd
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
|
||||
class NormalChatOutputParser(BaseOutputParser):
|
||||
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
return model_out_text
|
||||
|
||||
def parse_view_response(self, ai_text) -> str:
|
||||
return super().parse_view_response(ai_text)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
pass
|
pilot/scene/chat_knowledge/default/prompt.py (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
import builtins
|
||||
import importlib
|
||||
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
|
||||
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
|
||||
|
||||
|
||||
CFG = Config()
|
||||
|
||||
_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
|
||||
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
|
||||
已知内容:
|
||||
{context}
|
||||
问题:
|
||||
{question}
|
||||
"""
|
||||
|
||||
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
|
||||
PROMPT_NEED_NEED_STREAM_OUT = True
|
||||
|
||||
prompt = PromptTemplate(
|
||||
template_scene=ChatScene.ChatKnowledge.value,
|
||||
input_variables=["context", "question"],
|
||||
response_format=None,
|
||||
template_define=None,
|
||||
template=_DEFAULT_TEMPLATE,
|
||||
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||
output_parser=NormalChatOutputParser(
|
||||
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
||||
|
||||
|
pilot/scene/chat_knowledge/url/chat.py (new file, 71 lines)
@@ -0,0 +1,71 @@
|
||||
|
||||
from pilot.scene.base_chat import BaseChat, logger, headers
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.common.markdown_text import (
|
||||
generate_markdown_table,
|
||||
generate_htm_table,
|
||||
datas_to_table_html,
|
||||
)
|
||||
|
||||
from pilot.configs.model_config import (
|
||||
DATASETS_DIR,
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
LLM_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
)
|
||||
|
||||
from pilot.scene.chat_normal.prompt import prompt
|
||||
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ChatUrlKnowledge (BaseChat):
|
||||
chat_scene: str = ChatScene.ChatUrlKnowledge.value
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, url):
|
||||
""" """
|
||||
super().__init__(temperature=temperature,
|
||||
max_new_tokens=max_new_tokens,
|
||||
chat_mode=ChatScene.ChatUrlKnowledge,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input)
|
||||
self.url = url
|
||||
vector_store_config = {
|
||||
"vector_store_name": url,
|
||||
"text_field": "content",
|
||||
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
}
|
||||
self.knowledge_embedding_client = KnowledgeEmbedding(
|
||||
file_path=url,
|
||||
model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||
local_persist=False,
|
||||
vector_store_config=vector_store_config,
|
||||
)
|
||||
|
||||
# Embed the URL source into the vector store
|
||||
self.knowledge_embedding_client.knowledge_embedding()
|
||||
|
||||
def generate_input_values(self):
|
||||
docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
|
||||
docs = docs[:2000]
|
||||
input_values = {
|
||||
"context": docs,
|
||||
"question": self.current_user_input
|
||||
}
|
||||
return input_values
|
||||
|
||||
def do_with_prompt_response(self, prompt_response):
|
||||
return prompt_response
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatUrlKnowledge.value
|
pilot/scene/chat_knowledge/url/out_parser.py (new file, 22 lines)
@@ -0,0 +1,22 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, NamedTuple
|
||||
import pandas as pd
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
|
||||
class NormalChatOutputParser(BaseOutputParser):
|
||||
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
return model_out_text
|
||||
|
||||
def parse_view_response(self, ai_text) -> str:
|
||||
return super().parse_view_response(ai_text)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
pass
|
pilot/scene/chat_knowledge/url/prompt.py (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
import builtins
|
||||
import importlib
|
||||
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
|
||||
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
|
||||
|
||||
|
||||
CFG = Config()
|
||||
|
||||
_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
|
||||
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
|
||||
已知内容:
|
||||
{context}
|
||||
问题:
|
||||
{question}
|
||||
"""
|
||||
|
||||
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
|
||||
PROMPT_NEED_NEED_STREAM_OUT = True
|
||||
|
||||
prompt = PromptTemplate(
|
||||
template_scene=ChatScene.ChatUrlKnowledge.value,
|
||||
input_variables=["context", "question"],
|
||||
response_format=None,
|
||||
template_define=None,
|
||||
template=_DEFAULT_TEMPLATE,
|
||||
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||
output_parser=NormalChatOutputParser(
|
||||
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
||||
|
||||
|
pilot/scene/chat_normal/chat.py (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
|
||||
from pilot.scene.base_chat import BaseChat, logger, headers
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.common.markdown_text import (
|
||||
generate_markdown_table,
|
||||
generate_htm_table,
|
||||
datas_to_table_html,
|
||||
)
|
||||
from pilot.scene.chat_normal.prompt import prompt
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ChatNormal(BaseChat):
|
||||
chat_scene: str = ChatScene.ChatNormal.value
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input):
|
||||
""" """
|
||||
super().__init__(temperature=temperature,
|
||||
max_new_tokens=max_new_tokens,
|
||||
chat_mode=ChatScene.ChatNormal,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input)
|
||||
|
||||
def generate_input_values(self):
|
||||
input_values = {
|
||||
"input": self.current_user_input
|
||||
}
|
||||
return input_values
|
||||
|
||||
def do_with_prompt_response(self, prompt_response):
|
||||
return prompt_response
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatNormal.value
|
pilot/scene/chat_normal/out_parser.py (new file, 22 lines)
@@ -0,0 +1,22 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, NamedTuple
|
||||
import pandas as pd
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
|
||||
class NormalChatOutputParser(BaseOutputParser):
|
||||
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
return model_out_text
|
||||
|
||||
def parse_view_response(self, ai_text) -> str:
|
||||
return super().parse_view_response(ai_text)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
pass
|
@ -1,31 +1,33 @@
import builtins
import importlib

from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle

from pilot.scene.chat_normal.out_parser import NormalChatOutputParser


def stream_write_and_read(lst):
    # Flatten the iterable lst with yield from
    yield from lst
    while True:
        val = yield
        lst.append(val)
CFG = Config()

PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_NEED_STREAM_OUT = True

prompt = PromptTemplate(
    template_scene=ChatScene.ChatNormal.value,
    input_variables=["input"],
    response_format=None,
    template_define=None,
    template=None,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=NormalChatOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
)


if __name__ == "__main__":
    # Create an empty list
    my_list = []
CFG.prompt_templates.update({prompt.template_scene: prompt})

    # Write data through the generator
    stream_writer = stream_write_and_read(my_list)
    next(stream_writer)
    stream_writer.send(10)
    print(1)
    stream_writer.send(20)
    print(2)
    stream_writer.send(30)
    print(3)

    # Read data through the generator
    stream_reader = stream_write_and_read(my_list)
    next(stream_reader)
    print(stream_reader.send(None))
    print(stream_reader.send(None))
    print(stream_reader.send(None))

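The ChatNormal template is registered with template=None and stream_out=True, so whichever chat picks it up is expected to stream its answer. A minimal sketch of how a caller could branch on that flag, using only what the two new modules above define (the real dispatch lives in the webserver changes below):

# Hypothetical consumer of the registry built above.
from pilot.configs.config import Config
from pilot.scene.base import ChatScene

cfg = Config()
normal_tmpl = cfg.prompt_templates[ChatScene.ChatNormal.value]

if normal_tmpl.stream_out:
    print("stream tokens to the UI as they arrive")
else:
    print("wait for the full response, then render it once")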
@ -1,6 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import traceback
import argparse
import datetime
import json
@ -9,7 +9,6 @@ import shutil
import sys
import time
import uuid
from urllib.parse import urljoin

import gradio as gr
import requests
@ -216,19 +215,26 @@ def post_process_code(code):
    return code

def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene:
def get_chat_mode(selected, param=None) -> ChatScene:
    if chat_mode_title['chat_use_plugin'] == selected:
        return ChatScene.ChatExecution
    elif chat_mode_title['knowledge_qa'] == selected:
        mode = param
        if mode == conversation_types["default_knownledge"]:
            return ChatScene.ChatKnowledge
        elif mode == conversation_types["custome"]:
            return ChatScene.ChatNewKnowledge
        elif mode == conversation_types["url"]:
            return ChatScene.ChatUrlKnowledge
        else:
            return ChatScene.ChatNormal
    else:
        if sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
            return ChatScene.ChatWithDb
        sql_mode = param
        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
            return ChatScene.ChatWithDbExecute
        else:
            return ChatScene.ChatWithDbQA

    return ChatScene.ChatNormal

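The old signature took the whole UI state; the new one takes the selected tab plus a single optional param (the knowledge-QA mode or the SQL mode, depending on the tab). A sketch of the new call pattern, where the tab labels are whatever chat_mode_title holds at runtime:

# Illustrative calls against the new signature.
scene = get_chat_mode(chat_mode_title['knowledge_qa'], conversation_types["url"])
assert scene == ChatScene.ChatUrlKnowledge

scene = get_chat_mode(chat_mode_title['chat_use_plugin'])  # param defaults to None
assert scene == ChatScene.ChatExecution

# Any other tab falls through to the SQL branch.
scene = get_chat_mode("some other tab", conversation_sql_mode["auto_execute_ai_response"])
assert scene == ChatScene.ChatWithDbExecute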
def chatbot_callback(state, message):
    print(f"chatbot_callback:{message}")
@ -237,244 +243,99 @@ def chatbot_callback(state, message):

def http_bot(
    state, selected, plugin_selector, mode, sql_mode, db_selector, url_input, temperature, max_new_tokens, request: gr.Request
    state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input, knowledge_name
):
    logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}")
    start_tstamp = time.time()
    scene: ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector)
    print(f"now chat scene:{scene.value}")
    model_name = CFG.LLM_MODEL

    if ChatScene.ChatWithDb == scene:
        logger.info("chat with db mode use new architecture design!")
    logger.info(f"User message send!{state.conv_id},{selected}")
    if chat_mode_title['knowledge_qa'] == selected:
        scene: ChatScene = get_chat_mode(selected, mode)
    elif chat_mode_title['chat_use_plugin'] == selected:
        scene: ChatScene = get_chat_mode(selected)
    else:
        scene: ChatScene = get_chat_mode(selected, sql_mode)
    print(f"chat scene:{scene.value}")

    if ChatScene.ChatWithDbExecute == scene:
        chat_param = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "chat_session_id": state.conv_id,
            "db_name": db_selector,
            "user_input": state.last_user_input
        }
        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
    elif ChatScene.ChatWithDbQA == scene:
        chat_param = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "chat_session_id": state.conv_id,
            "db_name": db_selector,
            "user_input": state.last_user_input,
        }
        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
        chat.call()

        state.messages[-1][-1] = chat.current_ai_response()
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    elif ChatScene.ChatExecution == scene:
        logger.info("plugin mode use new architecture design!")
        chat_param = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "chat_session_id": state.conv_id,
            "plugin_selector": plugin_selector,
            "user_input": state.last_user_input,
        }
        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
        strem_generate = chat.stream_call()
    elif ChatScene.ChatNormal == scene:
        chat_param = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "chat_session_id": state.conv_id,
            "user_input": state.last_user_input,
        }
        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
    elif ChatScene.ChatKnowledge == scene:
        chat_param = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "chat_session_id": state.conv_id,
            "user_input": state.last_user_input,
        }
        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
    elif ChatScene.ChatNewKnowledge == scene:
        chat_param = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "chat_session_id": state.conv_id,
            "user_input": state.last_user_input,
            "knowledge_name": knowledge_name
        }
        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
    elif ChatScene.ChatUrlKnowledge == scene:
        chat_param = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "chat_session_id": state.conv_id,
            "user_input": state.last_user_input,
            "url": url_input
        }
        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)

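Every branch above builds the same base chat_param dict and only adds one scene-specific key before calling CHAT_FACTORY.get_implementation. A possible consolidation, shown purely to illustrate the pattern; the extra-argument names are taken from the branches above and none of this is part of the committed code:

# Sketch only: a table of scene -> extra keyword arguments.
scene_extra_args = {
    ChatScene.ChatWithDbExecute: {"db_name": db_selector},
    ChatScene.ChatWithDbQA: {"db_name": db_selector},
    ChatScene.ChatExecution: {"plugin_selector": plugin_selector},
    ChatScene.ChatNormal: {},
    ChatScene.ChatKnowledge: {},
    ChatScene.ChatNewKnowledge: {"knowledge_name": knowledge_name},
    ChatScene.ChatUrlKnowledge: {"url": url_input},
}

chat_param = {
    "temperature": temperature,
    "max_new_tokens": max_new_tokens,
    "chat_session_id": state.conv_id,
    "user_input": state.last_user_input,
    **scene_extra_args[scene],
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)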
        for msg in strem_generate:
            state.messages[-1][-1] = msg
    if not chat.prompt_template.stream_out:
        logger.info("not stream out, wait model response!")
        state.messages[-1][-1] = chat.nostream_call()
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
    else:
        logger.info("stream out start!")
        try:
            stream_gen = chat.stream_call()
            for msg in stream_gen:
                state.messages[-1][-1] = msg
                yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
        except Exception as e:
            print(traceback.format_exc())
            state.messages[-1][-1] = "Error:" + str(e)
            yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    # def generate_numbers():
    #     for i in range(10):
    #         time.sleep(0.5)
    #         yield f"Message:{i}"
    #
    # def showMessage(message):
    #     return message
    #
    # for n in generate_numbers():
    #     state.messages[-1][-1] = n + "▌"
    #     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
    else:

        dbname = db_selector
        # TODO: this request should be combined with the existing knowledge base so the answer is grounded in it; the prompt still needs further tuning
        if state.skip_next:
            # This generate call is skipped due to invalid inputs
            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
            return

        if len(state.messages) == state.offset + 2:
            query = state.messages[-2][1]

            template_name = "conv_one_shot"
            new_state = conv_templates[template_name].copy()
            # Add a context hint to the prompt so the chat answers from existing knowledge. Should the hint go only in the first turn, or in every turn?
            # If the user's questions vary a lot between turns, the hint should be added in every turn.
            if db_selector:
                new_state.append_message(
                    new_state.roles[0], gen_sqlgen_conversation(dbname) + query
                )
                new_state.append_message(new_state.roles[1], None)
            else:
                new_state.append_message(new_state.roles[0], query)
                new_state.append_message(new_state.roles[1], None)

            new_state.conv_id = uuid.uuid4().hex
            state = new_state
        else:
            ### Follow-up turns
            query = state.messages[-2][1]
            # The first turn needs the prompt prefix
            if mode == conversation_types["custome"]:
                template_name = "conv_one_shot"
                new_state = conv_templates[template_name].copy()
                # Add a context hint to the prompt so the chat answers from existing knowledge. Should the hint go only in the first turn, or in every turn?
                # If the user's questions vary a lot between turns, the hint should be added in every turn.
                if db_selector:
                    new_state.append_message(
                        new_state.roles[0], gen_sqlgen_conversation(dbname) + query
                    )
                    new_state.append_message(new_state.roles[1], None)
                else:
                    new_state.append_message(new_state.roles[0], query)
                    new_state.append_message(new_state.roles[1], None)
                state = new_state

        prompt = state.get_prompt()
        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
        if mode == conversation_types["default_knownledge"] and not db_selector:
            vector_store_config = {
                "vector_store_name": "default",
                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
            }
            knowledge_embedding_client = KnowledgeEmbedding(
                file_path="",
                model_name=LLM_MODEL_CONFIG["text2vec"],
                local_persist=False,
                vector_store_config=vector_store_config,
            )
            query = state.messages[-2][1]
            docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
            prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
            state.messages[-2][1] = query
            skip_echo_len = len(prompt.replace("</s>", " ")) + 1

        if mode == conversation_types["custome"] and not db_selector:
            print("vector store name: ", vector_store_name["vs_name"])
            vector_store_config = {
                "vector_store_name": vector_store_name["vs_name"],
                "text_field": "content",
                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
            }
            knowledge_embedding_client = KnowledgeEmbedding(
                file_path="",
                model_name=LLM_MODEL_CONFIG["text2vec"],
                local_persist=False,
                vector_store_config=vector_store_config,
            )
            query = state.messages[-2][1]
            docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
            prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)

            state.messages[-2][1] = query
            skip_echo_len = len(prompt.replace("</s>", " ")) + 1

        if mode == conversation_types["url"] and url_input:
            print("url: ", url_input)
            vector_store_config = {
                "vector_store_name": url_input,
                "text_field": "content",
                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
            }
            knowledge_embedding_client = KnowledgeEmbedding(
                file_path=url_input,
                model_name=LLM_MODEL_CONFIG["text2vec"],
                local_persist=False,
                vector_store_config=vector_store_config,
            )

            query = state.messages[-2][1]
            docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
            prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)

            state.messages[-2][1] = query
            skip_echo_len = len(prompt.replace("</s>", " ")) + 1

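The three legacy branches above differ only in where the vector store name and the source file come from. A sketch of how they could be collapsed into one helper, reusing only the names visible above; this helper is an illustration and is not part of the commit:

# Hypothetical helper mirroring the three branches above.
def build_knowledge_prompt_for(store_name, file_path, query, state):
    vector_store_config = {
        "vector_store_name": store_name,
        "text_field": "content",
        "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
    }
    client = KnowledgeEmbedding(
        file_path=file_path,
        model_name=LLM_MODEL_CONFIG["text2vec"],
        local_persist=False,
        vector_store_config=vector_store_config,
    )
    docs = client.similar_search(query, VECTOR_SEARCH_TOP_K)
    return KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)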
        # Make requests
        payload = {
            "model": model_name,
            "prompt": prompt,
            "temperature": float(temperature),
            "max_new_tokens": int(max_new_tokens),
            "stop": state.sep
            if state.sep_style == SeparatorStyle.SINGLE
            else state.sep2,
        }
        logger.info(f"Request: \n{payload}")

        # Streaming output
        state.messages[-1][-1] = "▌"
        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

        try:
            # Stream output
            response = requests.post(
                urljoin(CFG.MODEL_SERVER, "generate_stream"),
                headers=headers,
                json=payload,
                stream=True,
                timeout=20,
            )
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    data = json.loads(chunk.decode())

                    """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
                    """
                    if data["error_code"] == 0:
                        if "vicuna" in CFG.LLM_MODEL:
                            output = data["text"][skip_echo_len:].strip()
                        else:
                            output = data["text"].strip()

                        output = post_process_code(output)
                        state.messages[-1][-1] = output + "▌"
                        yield (state, state.to_gradio_chatbot()) + (
                            disable_btn,
                        ) * 5
                    else:
                        output = (
                            data["text"] + f" (error_code: {data['error_code']})"
                        )
                        state.messages[-1][-1] = output
                        yield (state, state.to_gradio_chatbot()) + (
                            disable_btn,
                            disable_btn,
                            disable_btn,
                            enable_btn,
                            enable_btn,
                        )
                        return

        except requests.exceptions.RequestException as e:
            state.messages[-1][-1] = server_error_msg + " (error_code: 4)"
            yield (state, state.to_gradio_chatbot()) + (
                disable_btn,
                disable_btn,
                disable_btn,
                enable_btn,
                enable_btn,
            )
            return

        state.messages[-1][-1] = state.messages[-1][-1][:-1]
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

        # Record the run log
        finish_tstamp = time.time()
        logger.info(f"{output}")

        with open(get_conv_log_filename(), "a") as fout:
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": request.client.host,
            }
            fout.write(json.dumps(data) + "\n")

    if state.messages[-1][-1].endswith("▌"):
        state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5


block_css = (
    code_highlight_css
@ -556,6 +417,7 @@ def build_single_model_ui():
                value=dbs[0] if len(models) > 0 else "",
                interactive=True,
                show_label=True,
                name="db_selector"
            ).style(container=False)

            sql_mode = gr.Radio(
@ -565,6 +427,7 @@ def build_single_model_ui():
                ],
                show_label=False,
                value=get_lang_text("sql_generate_mode_none"),
                name="sql_mode"
            )
            sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
            sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
@ -581,7 +444,8 @@ def build_single_model_ui():
                value="",
                interactive=True,
                show_label=True,
                type="value"
                type="value",
                name="plugin_selector"
            ).style(container=False)

            def plugin_change(evt: gr.SelectData):  # SelectData is a subclass of EventData
@ -602,13 +466,14 @@ def build_single_model_ui():
                ],
                show_label=False,
                value=llm_native_dialogue,
                name="mode"
            )
            vs_setting = gr.Accordion(
                get_lang_text("configure_knowledge_base"), open=False
                get_lang_text("configure_knowledge_base"), open=False, visible=False
            )
            mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)

            url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True)
            url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False, name="url_input")
            def show_url_input(evt: gr.SelectData):
                if evt.value == url_knowledge_dialogue:
                    return gr.update(visible=True)
@ -619,7 +484,7 @@ def build_single_model_ui():

            with vs_setting:
                vs_name = gr.Textbox(
                    label=get_lang_text("new_klg_name"), lines=1, interactive=True
                    label=get_lang_text("new_klg_name"), lines=1, interactive=True, name="vs_name"
                )
                vs_add = gr.Button(get_lang_text("add_as_new_klg"))
                with gr.Column() as doc2vec:
@ -664,10 +529,14 @@ def build_single_model_ui():
        clear_btn = gr.Button(value=get_lang_text("clear_box"), interactive=False)

    gr.Markdown(learn_more_markdown)

    params = [plugin_selector, mode, sql_mode, db_selector, url_input, vs_name]


    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        http_bot,
        [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
        [state, selected, temperature, max_output_tokens] + params,
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@ -676,7 +545,7 @@ def build_single_model_ui():
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
        [state, selected, temperature, max_output_tokens] + params,
        [state, chatbot] + btn_list,
    )

@ -684,7 +553,7 @@ def build_single_model_ui():
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
        [state, selected, temperature, max_output_tokens] + params,
        [state, chatbot] + btn_list,
    )
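After this change every .then(http_bot, ...) call passes the same input list: state, the selected tab, the two generation sliders, and then the params list defined above. Gradio passes these positionally, so they have to line up with http_bot's new signature; a sketch of the correspondence, using only names that appear above and shown purely to make the mapping explicit:

# Positional mapping (illustration only):
#   [state, selected, temperature, max_output_tokens] + params
# = [state, selected, temperature, max_output_tokens,
#    plugin_selector, mode, sql_mode, db_selector, url_input, vs_name]
# feeds
#   http_bot(state, selected, temperature, max_new_tokens,
#            plugin_selector, mode, sql_mode, db_selector, url_input, knowledge_name)
inputs = [state, selected, temperature, max_output_tokens] + params
textbox.submit(add_text, [state, textbox], [state, chatbot, textbox] + btn_list).then(
    http_bot, inputs, [state, chatbot] + btn_list
)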
    vs_add.click(

@ -14,6 +14,7 @@ fonttools==4.38.0
frozenlist==1.3.3
huggingface-hub==0.13.4
importlib-resources==5.12.0

kiwisolver==1.4.4
matplotlib==3.7.0
multidict==6.0.4