Merge branch 'summary' into dev

# Conflicts:
#	pilot/model/llm_out/guanaco_llm.py
#	pilot/out_parser/base.py
#	pilot/scene/base_chat.py
yhjun1026 2023-06-01 19:15:00 +08:00
commit 7e6cf9c9e0
43 changed files with 308 additions and 276 deletions

View File

@@ -55,8 +55,6 @@ def fix_and_parse_json(
         logger.error("参数解析错误", e)

 def correct_json(json_to_load: str) -> str:
     """
     Correct common JSON errors.

View File

@@ -1,6 +1,7 @@
 from collections import OrderedDict
 from collections import deque
+

 class FixedSizeDict(OrderedDict):
     def __init__(self, max_size):
         super().__init__()
@@ -11,6 +12,7 @@ class FixedSizeDict(OrderedDict):
             self.popitem(last=False)
         super().__setitem__(key, value)
+

 class FixedSizeList:
     def __init__(self, max_size):
         self.max_size = max_size

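For context, the FixedSizeDict touched above is a small LRU-style container: once max_size entries exist, the oldest key is evicted before a new one is stored. A minimal self-contained sketch of the same idea (the eviction threshold and usage are our reading of the diff, not verified against the full file):

    from collections import OrderedDict

    class FixedSizeDict(OrderedDict):
        """Dict that evicts its oldest entry once max_size is reached."""

        def __init__(self, max_size):
            super().__init__()
            self.max_size = max_size

        def __setitem__(self, key, value):
            # Evict the least recently inserted item when full (assumed behavior).
            if len(self) >= self.max_size:
                self.popitem(last=False)
            super().__setitem__(key, value)

    d = FixedSizeDict(max_size=2)
    d["a"] = 1
    d["b"] = 2
    d["c"] = 3
    print(list(d))  # ['b', 'c'] - 'a' was evicted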
View File

@@ -1,6 +1,6 @@
 from __future__ import annotations

 import sqlparse
 import regex as re
 import warnings
 from typing import Any, Iterable, List, Optional
 from pydantic import BaseModel, Field, root_validator, validator, Extra
@@ -283,16 +283,16 @@ class Database:
             return f"Error: {e}"

     def __write(self, session, write_sql):
         print(f"Write[{write_sql}]")
         db_cache = self.get_session_db(session)
         result = session.execute(text(write_sql))
         session.commit()
-        #TODO Subsequent optimization of dynamically specified database submission loss target problem
+        # TODO Subsequent optimization of dynamically specified database submission loss target problem
         session.execute(text(f"use `{db_cache}`"))
         print(f"SQL[{write_sql}], result:{result.rowcount}")
         return result.rowcount

-    def __query(self,session, query, fetch: str = "all"):
+    def __query(self, session, query, fetch: str = "all"):
         """
         only for query
         Args:
@@ -390,37 +390,44 @@ class Database:
         cmd_type = parts[0]

         # Dispatch on the SQL command type
-        if cmd_type == 'insert':
-            match = re.match(r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower())
+        if cmd_type == "insert":
+            match = re.match(
+                r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
+            )
             if match:
                 table_name, columns, values = match.groups()
                 # Split the column list and value list into individual items
-                columns = columns.split(',')
-                values = values.split(',')
+                columns = columns.split(",")
+                values = values.split(",")
                 # Build the WHERE clause
-                where_clause = " AND ".join([f"{col.strip()}={val.strip()}" for col, val in zip(columns, values)])
-                return f'SELECT * FROM {table_name} WHERE {where_clause}'
+                where_clause = " AND ".join(
+                    [
+                        f"{col.strip()}={val.strip()}"
+                        for col, val in zip(columns, values)
+                    ]
+                )
+                return f"SELECT * FROM {table_name} WHERE {where_clause}"
-        elif cmd_type == 'delete':
+        elif cmd_type == "delete":
             table_name = parts[2]  # delete from <table_name> ...
             # Return a SELECT over all rows of that table
-            return f'SELECT * FROM {table_name}'
+            return f"SELECT * FROM {table_name}"
-        elif cmd_type == 'update':
+        elif cmd_type == "update":
             table_name = parts[1]
-            set_idx = parts.index('set')
-            where_idx = parts.index('where')
+            set_idx = parts.index("set")
+            where_idx = parts.index("where")
             # Extract the field name from the `set` clause
-            set_clause = parts[set_idx + 1: where_idx][0].split('=')[0].strip()
+            set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip()
             # Extract the condition after `where`
-            where_clause = ' '.join(parts[where_idx + 1:])
+            where_clause = " ".join(parts[where_idx + 1 :])
             # Return a SELECT over the data being updated
-            return f'SELECT {set_clause} FROM {table_name} WHERE {where_clause}'
+            return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
         else:
             raise ValueError(f"Unsupported SQL command type: {cmd_type}")

     def __sql_parse(self, sql):
         sql = sql.strip()
         parsed = sqlparse.parse(sql)[0]
         sql_type = parsed.get_type()
@@ -429,8 +436,6 @@ class Database:
         print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}")
         return parsed, ttype, sql_type

     def get_indexes(self, table_name):
         """Get table indexes about specified table."""
         session = self._db_sessions()

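The hunk above rewrites a pending INSERT/UPDATE/DELETE into an equivalent SELECT so the affected rows can be previewed. A trimmed, standalone sketch of the INSERT case (regex and naming follow the diff; the error handling is our addition):

    import re

    def insert_to_select(write_sql: str) -> str:
        """Turn `insert into t (a, b) values (1, 2)` into a row-preview SELECT."""
        match = re.match(
            r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
        )
        if not match:
            raise ValueError(f"Unsupported SQL: {write_sql}")
        table_name, columns, values = match.groups()
        pairs = zip(columns.split(","), values.split(","))
        where_clause = " AND ".join(f"{c.strip()}={v.strip()}" for c, v in pairs)
        return f"SELECT * FROM {table_name} WHERE {where_clause}"

    print(insert_to_select("insert into users (id, name) values (1, 'bob')"))
    # SELECT * FROM users WHERE id=1 AND name='bob'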
View File

@@ -103,8 +103,12 @@ class Config(metaclass=Singleton):
         else:
             self.plugins_denylist = []

         ### Native SQL Execution Capability Control Configuration
-        self.NATIVE_SQL_CAN_RUN_DDL = os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") =="True"
-        self.NATIVE_SQL_CAN_RUN_WRITE = os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") =="True"
+        self.NATIVE_SQL_CAN_RUN_DDL = (
+            os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True"
+        )
+        self.NATIVE_SQL_CAN_RUN_WRITE = (
+            os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
+        )

         ### Local database connection configuration
         self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")

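Aside: the `== "True"` comparison above means any other value, including "true" or "1", disables the flag. A more forgiving parser, as a sketch (the helper name is ours, not the repo's):

    import os

    def env_flag(name: str, default: bool = True) -> bool:
        """Parse a boolean environment variable leniently (sketch)."""
        raw = os.getenv(name)
        if raw is None:
            return default
        return raw.strip().lower() in ("1", "true", "yes", "on")

    NATIVE_SQL_CAN_RUN_DDL = env_flag("NATIVE_SQL_CAN_RUN_DDL", default=True)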
View File

@@ -11,11 +11,9 @@ class BaseConnect(BaseModel, ABC):
     type
     driver: str

     def get_session(self, db_name: str):
         pass

     def get_table_names(self) -> Iterable[str]:
         pass

View File

@@ -11,8 +11,7 @@ class MySQLConnect(RDBMSDatabase):
     Usage:
     """

-    type:str = "MySQL"
+    type: str = "MySQL"
     connect_url = "mysql+pymysql://"
     default_db = ["information_schema", "performance_schema", "sys", "mysql"]

View File

@@ -106,7 +106,7 @@ class Conversation:
 conv_default = Conversation(
-    system = None,
+    system=None,
     roles=("human", "ai"),
     messages=[],
     offset=0,
@@ -298,7 +298,6 @@ chat_mode_title = {
     "sql_generate_diagnostics": get_lang_text("sql_analysis_and_diagnosis"),
     "chat_use_plugin": get_lang_text("chat_use_plugin"),
     "knowledge_qa": get_lang_text("knowledge_qa"),
 }

 conversation_sql_mode = {

View File

@@ -29,7 +29,6 @@ lang_dicts = {
         "url_input_label": "输入网页地址",
         "add_as_new_klg": "添加为新知识库",
         "add_file_to_klg": "向知识库中添加文件",
-
         "upload_file": "上传文件",
         "add_file": "添加文件",
         "upload_and_load_to_klg": "上传并加载到知识库",

View File

@@ -9,7 +9,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER

 @torch.inference_mode()
 def chatglm_generate_stream(
     model, tokenizer, params, device, context_len=2048, stream_interval=2
 ):
     """Generate text using chatglm model's chat api"""
     prompt = params["prompt"]
@@ -57,7 +57,7 @@ def chatglm_generate_stream(
     # i = 0
     for i, (response, new_hist) in enumerate(
         model.stream_chat(tokenizer, query, hist, **generate_kwargs)
     ):
         if echo:
             output = query + " " + response

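The hunk above is the ChatGLM streaming loop: stream_chat yields progressively longer responses, and the wrapper re-emits them (optionally echoing the query). A stripped-down sketch of that consumption pattern, assuming an already-loaded ChatGLM model and tokenizer:

    def stream_reply(model, tokenizer, query, history, stream_interval=2):
        """Yield partial responses every `stream_interval` iterations (sketch)."""
        for i, (response, _new_history) in enumerate(
            model.stream_chat(tokenizer, query, history)
        ):
            # Emitting on an interval keeps the HTTP stream chunky but cheap.
            if i % stream_interval == 0:
                yield response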
View File

@@ -83,9 +83,7 @@ class BaseOutputParser(ABC):
                 output = self.__post_process_code(output)
                 yield output
             else:
-                output = (
-                    data["text"] + f" (error_code: {data['error_code']})"
-                )
+                output = data["text"] + f" (error_code: {data['error_code']})"
                 yield output

     def parse_model_nostream_resp(self, response, sep: str):

View File

@@ -133,10 +133,8 @@ class PluginPromptGenerator:
         else:
             return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))

-
-    def generate_commands_string(self)->str:
+    def generate_commands_string(self) -> str:
         return f"{self._generate_numbered_list(self.commands, item_type='command')}"

     def generate_prompt_string(self) -> str:
         """

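For reference, _generate_numbered_list above renders command metadata as a 1-based numbered prompt section. A standalone sketch of the pattern (the sample items are ours):

    def numbered_list(items):
        """Render items as a 1-based numbered list for a prompt."""
        return "\n".join(f"{i + 1}. {item}" for i, item in enumerate(items))

    print(numbered_list(["search the web", "run SQL"]))
    # 1. search the web
    # 2. run SQL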
View File

@@ -33,13 +33,13 @@ class PromptTemplate(BaseModel, ABC):
     """A list of the names of the variables the prompt template expects."""
     template_scene: Optional[str]
     template_define: Optional[str]
     """this template define"""
     template: Optional[str]
     """The prompt template."""
     template_format: str = "f-string"
     """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
     response_format: Optional[str]
     """default use stream out"""
     stream_out: bool = True
     """"""
@@ -62,7 +62,9 @@ class PromptTemplate(BaseModel, ABC):
         if self.template:
             if self.response_format:
                 kwargs["response"] = json.dumps(self.response_format, indent=4)
-            return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
+            return DEFAULT_FORMATTER_MAPPING[self.template_format](
+                self.template, **kwargs
+            )

     def add_goals(self, goal: str) -> None:
         self.goals.append(goal)

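DEFAULT_FORMATTER_MAPPING above dispatches on template_format to a rendering function. A minimal sketch of that dispatch, assuming only f-string support (the mapping contents are our assumption):

    DEFAULT_FORMATTER_MAPPING = {
        # "f-string" templates render with str.format (assumed).
        "f-string": lambda template, **kwargs: template.format(**kwargs),
    }

    def render(template, template_format="f-string", **kwargs):
        return DEFAULT_FORMATTER_MAPPING[template_format](template, **kwargs)

    print(render("Answer in {dialect}.", dialect="mysql"))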
View File

@@ -129,7 +129,7 @@ class BaseChat(ABC):
     def stream_call(self):
         payload = self.__call_base()
-        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " "))
+        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
@@ -141,7 +141,7 @@ class BaseChat(ABC):
                 stream=True,
                 timeout=120,
             )
-            return response;
+            return response
             # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len)

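skip_echo_len above tells the stream parser how many characters of model output are just the echoed prompt; the `+ 11` looks like a model-specific fudge for extra text the server prepends. A sketch of how such an offset is used downstream (simplified, not the repo's parser):

    def strip_echo(full_output: str, skip_echo_len: int) -> str:
        """Drop the echoed prompt prefix from a streamed completion."""
        return full_output[skip_echo_len:].strip()

    prompt = "USER: hi ASSISTANT:"
    full = prompt + " hello there"
    print(strip_echo(full, len(prompt)))  # "hello there"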
View File

@@ -21,15 +21,21 @@ class ChatWithDbAutoExecute(BaseChat):
     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input):
+    def __init__(
+        self, temperature, max_new_tokens, chat_session_id, db_name, user_input
+    ):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatWithDbExecute,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatWithDbExecute,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         if not db_name:
-            raise ValueError(f"{ChatScene.ChatWithDbExecute.value} mode should chose db!")
+            raise ValueError(
+                f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
+            )
         self.db_name = db_name
         self.database = CFG.local_db
         # Prepare DB info (get a connection to the chosen database)
@@ -40,9 +46,7 @@ class ChatWithDbAutoExecute(BaseChat):
         try:
             from pilot.summary.db_summary_client import DBSummaryClient
         except ImportError:
-            raise ValueError(
-                "Could not import DBSummaryClient. "
-            )
+            raise ValueError("Could not import DBSummaryClient. ")
         input_values = {
             "input": self.current_user_input,
             "top_k": str(self.top_k),

View File

@@ -20,9 +20,8 @@ class DbChatOutputParser(BaseOutputParser):
     def __init__(self, sep: str, is_stream_out: bool):
         super().__init__(sep=sep, is_stream_out=is_stream_out)

     def parse_prompt_response(self, model_out_text):
-        clean_str = super().parse_prompt_response(model_out_text);
+        clean_str = super().parse_prompt_response(model_out_text)
         print("clean prompt response:", clean_str)
         response = json.loads(clean_str)
         sql, thoughts = response["sql"], response["thoughts"]

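The parser above expects the model to emit a JSON object with "sql" and "thoughts" keys. A defensive sketch of the same parse (the error handling is our addition):

    import json

    def parse_sql_response(clean_str: str):
        """Parse {'sql': ..., 'thoughts': ...} from a cleaned model reply (sketch)."""
        try:
            response = json.loads(clean_str)
            return response["sql"], response["thoughts"]
        except (json.JSONDecodeError, KeyError) as e:
            raise ValueError(f"Malformed model response: {e}") from e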
View File

@@ -62,4 +62,3 @@ prompt = PromptTemplate(
     ),
 )
 CFG.prompt_templates.update({prompt.template_scene: prompt})
-

View File

@@ -19,13 +19,17 @@ class ChatWithDbQA(BaseChat):
     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input):
+    def __init__(
+        self, temperature, max_new_tokens, chat_session_id, db_name, user_input
+    ):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatWithDbQA,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatWithDbQA,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         self.db_name = db_name
         if db_name:
             self.database = CFG.local_db
@@ -34,17 +38,16 @@ class ChatWithDbQA(BaseChat):
         self.top_k: int = 5

     def generate_input_values(self):
         table_info = ""
         dialect = "mysql"
         try:
             from pilot.summary.db_summary_client import DBSummaryClient
         except ImportError:
-            raise ValueError(
-                "Could not import DBSummaryClient. "
-            )
+            raise ValueError("Could not import DBSummaryClient. ")
         if self.db_name:
-            table_info = DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
+            table_info = DBSummaryClient.get_similar_tables(
+                dbname=self.db_name, query=self.current_user_input, topk=self.top_k
+            )
             # table_info = self.database.table_simple_info(self.db_connect)
             dialect = self.database.dialect
@@ -52,7 +55,7 @@ class ChatWithDbQA(BaseChat):
             "input": self.current_user_input,
             "top_k": str(self.top_k),
             "dialect": dialect,
-            "table_info": table_info
+            "table_info": table_info,
         }
         return input_values

View File

@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")


 class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

View File

@@ -27,7 +27,6 @@ Pay attention to use only the column names that you can see in the schema descri
 """

-
 PROMPT_SEP = SeparatorStyle.SINGLE.value

 PROMPT_NEED_NEED_STREAM_OUT = True
@@ -37,7 +36,7 @@ prompt = PromptTemplate(
     input_variables=["input", "table_info", "dialect", "top_k"],
     response_format=None,
     template_define=PROMPT_SCENE_DEFINE,
-    template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX ,
+    template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX,
     stream_out=PROMPT_NEED_NEED_STREAM_OUT,
     output_parser=NormalChatOutputParser(
         sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
@@ -45,4 +44,3 @@ prompt = PromptTemplate(
 )

 CFG.prompt_templates.update({prompt.template_scene: prompt})
-

View File

@@ -14,56 +14,72 @@ from pilot.scene.chat_execution.prompt import prompt

 CFG = Config()


 class ChatWithPlugin(BaseChat):
     chat_scene: str = ChatScene.ChatExecution.value
-    plugins_prompt_generator:PluginPromptGenerator
+    plugins_prompt_generator: PluginPromptGenerator
     select_plugin: str = None

-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, plugin_selector:str=None):
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatExecution,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+    def __init__(
+        self,
+        temperature,
+        max_new_tokens,
+        chat_session_id,
+        user_input,
+        plugin_selector: str = None,
+    ):
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatExecution,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         self.plugins_prompt_generator = PluginPromptGenerator()
         self.plugins_prompt_generator.command_registry = CFG.command_registry
         # Load the commands available from plugins
         self.select_plugin = plugin_selector
         if self.select_plugin:
             for plugin in CFG.plugins:
-                if plugin._name == plugin_selector :
+                if plugin._name == plugin_selector:
                     if not plugin.can_handle_post_prompt():
                         continue
-                    self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator)
+                    self.plugins_prompt_generator = plugin.post_prompt(
+                        self.plugins_prompt_generator
+                    )
         else:
             for plugin in CFG.plugins:
                 if not plugin.can_handle_post_prompt():
                     continue
-                self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator)
+                self.plugins_prompt_generator = plugin.post_prompt(
+                    self.plugins_prompt_generator
+                )

     def generate_input_values(self):
         input_values = {
             "input": self.current_user_input,
-            "constraints": self.__list_to_prompt_str(list(self.plugins_prompt_generator.constraints)),
-            "commands_infos": self.plugins_prompt_generator.generate_commands_string()
+            "constraints": self.__list_to_prompt_str(
+                list(self.plugins_prompt_generator.constraints)
+            ),
+            "commands_infos": self.plugins_prompt_generator.generate_commands_string(),
         }
         return input_values

     def do_with_prompt_response(self, prompt_response):
         ## plugin command run
-        return execute_command(str(prompt_response.command.get('name')), prompt_response.command.get('args',{}), self.plugins_prompt_generator)
+        return execute_command(
+            str(prompt_response.command.get("name")),
+            prompt_response.command.get("args", {}),
+            self.plugins_prompt_generator,
+        )

     def chat_show(self):
         super().chat_show()

     def __list_to_prompt_str(self, list: List) -> str:
         if list:
-            separator = '\n'
+            separator = "\n"
             return separator.join(list)
         else:
             return ""

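The do_with_prompt_response hunk above routes the model's parsed command to execute_command. A toy version of that dispatch, with a plain dict standing in for the plugin registry (everything except the command/args shape is ours):

    def execute_command(name, args, registry):
        """Look up a command by name and invoke it with model-supplied args."""
        func = registry.get(name)
        if func is None:
            return f"Unknown command: {name}"
        return func(**args)

    registry = {"add": lambda a, b: a + b}
    command = {"name": "add", "args": {"a": 1, "b": 2}}
    print(execute_command(command["name"], command.get("args", {}), registry))  # 3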
View File

@@ -10,14 +10,13 @@ from pilot.configs.model_config import LOGDIR

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")


 class PluginAction(NamedTuple):
     command: Dict
     thoughts: Dict


 class PluginChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         response = json.loads(super().parse_prompt_response(model_out_text))
         command, thoughts = response["command"], response["thoughts"]
@@ -25,7 +24,7 @@ class PluginChatOutputParser(BaseOutputParser):
     def parse_view_response(self, speak, data) -> str:
         ### tool out data to table view
-        print(f"parse_view_response:{speak},{str(data)}" )
+        print(f"parse_view_response:{speak},{str(data)}")
         view_text = f"##### {speak}" + "\n" + str(data)
         return view_text

View File

@@ -53,7 +53,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False

 prompt = PromptTemplate(
     template_scene=ChatScene.ChatExecution.value,
     input_variables=["input", "constraints", "commands_infos", "response"],
     response_format=json.dumps(RESPONSE_FORMAT, indent=4),
     template_define=PROMPT_SCENE_DEFINE,
     template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,

View File

@@ -11,6 +11,7 @@ from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
 from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
 from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary

+
 class ChatFactory(metaclass=Singleton):
     @staticmethod
     def get_implementation(chat_mode, **kwargs):

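ChatFactory.get_implementation above picks a BaseChat subclass by its chat_scene value. A plausible, self-contained sketch of such a factory (the subclass scan is our assumption about its internals, not confirmed by the diff):

    class BaseChat:
        chat_scene: str = ""

    class ChatNormal(BaseChat):
        chat_scene = "chat_normal"

        def __init__(self, **kwargs):
            self.params = kwargs

    def get_implementation(chat_mode, **kwargs):
        """Instantiate the BaseChat subclass whose chat_scene matches (sketch)."""
        for cls in BaseChat.__subclasses__():
            if cls.chat_scene == chat_mode:
                return cls(**kwargs)
        raise ValueError(f"Invalid chat mode: {chat_mode}")

    chat = get_implementation("chat_normal", user_input="hi")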
View File

@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -24,18 +23,22 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 CFG = Config()


-class ChatNewKnowledge (BaseChat):
+class ChatNewKnowledge(BaseChat):
     chat_scene: str = ChatScene.ChatNewKnowledge.value

     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, knowledge_name):
+    def __init__(
+        self, temperature, max_new_tokens, chat_session_id, user_input, knowledge_name
+    ):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatNewKnowledge,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatNewKnowledge,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         self.knowledge_name = knowledge_name
         vector_store_config = {
             "vector_store_name": knowledge_name,
@@ -49,21 +52,17 @@ class ChatNewKnowledge(BaseChat):
             vector_store_config=vector_store_config,
         )

     def generate_input_values(self):
-        docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
+        docs = self.knowledge_embedding_client.similar_search(
+            self.current_user_input, VECTOR_SEARCH_TOP_K
+        )
         docs = docs[:2000]
-        input_values = {
-            "context": docs,
-            "question": self.current_user_input
-        }
+        input_values = {"context": docs, "question": self.current_user_input}
         return input_values

     def do_with_prompt_response(self, prompt_response):
         return prompt_response

     @property
     def chat_type(self) -> str:
         return ChatScene.ChatNewKnowledge.value

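The generate_input_values pattern above (repeated in the default, URL, and custom knowledge chats below) is a minimal RAG step: embed the question, pull the top-K similar chunks, truncate, and hand both to the prompt template. A sketch with a stubbed search client (the stub is ours):

    VECTOR_SEARCH_TOP_K = 5

    class StubEmbeddingClient:
        def similar_search(self, query, topk):
            # A real client queries the vector store; we fake matching chunks.
            return [f"chunk about {query} #{i}" for i in range(topk)]

    client = StubEmbeddingClient()
    question = "how to create an index?"
    docs = client.similar_search(question, VECTOR_SEARCH_TOP_K)
    docs = docs[:2000]  # crude context budget, as in the diff
    input_values = {"context": docs, "question": question}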
View File

@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")


 class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

View File

@@ -23,7 +23,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用
 """

-
 PROMPT_SEP = SeparatorStyle.SINGLE.value

 PROMPT_NEED_NEED_STREAM_OUT = True
@@ -42,5 +41,3 @@ prompt = PromptTemplate(

 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-

View File

@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -24,43 +23,42 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 CFG = Config()


-class ChatDefaultKnowledge (BaseChat):
+class ChatDefaultKnowledge(BaseChat):
     chat_scene: str = ChatScene.ChatKnowledge.value

     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input):
+    def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatKnowledge,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatKnowledge,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         vector_store_config = {
             "vector_store_name": "default",
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
             file_path="",
             model_name=LLM_MODEL_CONFIG["text2vec"],
             local_persist=False,
             vector_store_config=vector_store_config,
         )

     def generate_input_values(self):
-        docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
+        docs = self.knowledge_embedding_client.similar_search(
+            self.current_user_input, VECTOR_SEARCH_TOP_K
+        )
         docs = docs[:2000]
-        input_values = {
-            "context": docs,
-            "question": self.current_user_input
-        }
+        input_values = {"context": docs, "question": self.current_user_input}
         return input_values

     def do_with_prompt_response(self, prompt_response):
         return prompt_response

     @property
     def chat_type(self) -> str:
         return ChatScene.ChatKnowledge.value

View File

@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")


 class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

View File

@@ -20,7 +20,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用
 """

-
 PROMPT_SEP = SeparatorStyle.SINGLE.value

 PROMPT_NEED_NEED_STREAM_OUT = True
@@ -39,5 +38,3 @@ prompt = PromptTemplate(

 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-

View File

@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
 from pilot.configs.config import Config
@@ -8,34 +7,42 @@ from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt
 CFG = Config()


-class InnerChatDBSummary (BaseChat):
+class InnerChatDBSummary(BaseChat):
     chat_scene: str = ChatScene.InnerChatDBSummary.value

     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, db_select, db_summary):
+    def __init__(
+        self,
+        temperature,
+        max_new_tokens,
+        chat_session_id,
+        user_input,
+        db_select,
+        db_summary,
+    ):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.InnerChatDBSummary,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
-        self.db_name = db_select
-        self.db_summary = db_summary
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.InnerChatDBSummary,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
+        self.db_input = db_select
+        self.db_summary = db_summary

     def generate_input_values(self):
         input_values = {
-            "db_input": self.db_name,
-            "db_profile_summary": self.db_summary
+            "db_input": self.db_input,
+            "db_profile_summary": self.db_summary,
         }
         return input_values

     def do_with_prompt_response(self, prompt_response):
         return prompt_response

     @property
     def chat_type(self) -> str:
         return ChatScene.InnerChatDBSummary.value

View File

@@ -10,13 +10,15 @@ from pilot.configs.model_config import LOGDIR

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")


 class NormalChatOutputParser(BaseOutputParser):
-    def parse_prompt_response(self, model_out_text) -> T:
-        return model_out_text
+    def parse_prompt_response(self, model_out_text):
+        clean_str = super().parse_prompt_response(model_out_text)
+        print("clean prompt response:", clean_str)
+        return clean_str

     def parse_view_response(self, ai_text, data) -> str:
-        return ai_text["table"]
+        return ai_text

     def get_format_instructions(self) -> str:
         pass

View File

@@ -7,33 +7,30 @@ from pilot.configs.config import Config
 from pilot.scene.base import ChatScene
 from pilot.common.schema import SeparatorStyle
-from pilot.scene.chat_knowledge.inner_db_summary.out_parser import NormalChatOutputParser
+from pilot.scene.chat_knowledge.inner_db_summary.out_parser import (
+    NormalChatOutputParser,
+)

 CFG = Config()

-PROMPT_SCENE_DEFINE ="""""" 
+PROMPT_SCENE_DEFINE = """"""

 _DEFAULT_TEMPLATE = """
 Based on the following known database information?, answer which tables are involved in the user input.
 Known database information:{db_profile_summary}
 Input:{db_input}
-You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads
-The response format must be JSON, and the key of JSON must be "table".
 """

 PROMPT_RESPONSE = """You must respond in JSON format as following format:
 {response}
+The response format must be JSON, and the key of JSON must be "table".
+Ensure the response is correct json and can be parsed by Python json.loads
 """

-RESPONSE_FORMAT = {
-    "table": ["orders", "products"]
-}
+RESPONSE_FORMAT = {"table": ["orders", "products"]}

 PROMPT_SEP = SeparatorStyle.SINGLE.value
@@ -54,5 +51,3 @@ prompt = PromptTemplate(

 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-

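This template is a constrained-output trick: show the model a literal example of the JSON shape (RESPONSE_FORMAT) and decode its reply with json.loads. A round-trip sketch (the sample reply is ours):

    import json

    RESPONSE_FORMAT = {"table": ["orders", "products"]}
    prompt_response = (
        "You must respond in JSON format as following format:\n"
        + json.dumps(RESPONSE_FORMAT, indent=4)
    )

    # A well-behaved model reply can then be decoded directly:
    model_reply = '{"table": ["orders"]}'
    tables = json.loads(model_reply)["table"]
    print(tables)  # ['orders']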
View File

@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -24,18 +23,20 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 CFG = Config()


-class ChatUrlKnowledge (BaseChat):
+class ChatUrlKnowledge(BaseChat):
     chat_scene: str = ChatScene.ChatUrlKnowledge.value

     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, url):
+    def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, url):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatUrlKnowledge,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatUrlKnowledge,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         self.url = url
         vector_store_config = {
             "vector_store_name": url,
@@ -54,19 +55,16 @@ class ChatUrlKnowledge(BaseChat):
         self.knowledge_embedding_client.knowledge_embedding()

     def generate_input_values(self):
-        docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
+        docs = self.knowledge_embedding_client.similar_search(
+            self.current_user_input, VECTOR_SEARCH_TOP_K
+        )
         docs = docs[:2000]
-        input_values = {
-            "context": docs,
-            "question": self.current_user_input
-        }
+        input_values = {"context": docs, "question": self.current_user_input}
         return input_values

     def do_with_prompt_response(self, prompt_response):
         return prompt_response

     @property
     def chat_type(self) -> str:
         return ChatScene.ChatUrlKnowledge.value

View File

@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")


 class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

View File

@@ -20,7 +20,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用
 """

-
 PROMPT_SEP = SeparatorStyle.SINGLE.value

 PROMPT_NEED_NEED_STREAM_OUT = True
@@ -39,5 +38,3 @@ prompt = PromptTemplate(

 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-

View File

@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -19,25 +18,23 @@ class ChatNormal(BaseChat):
     """Number of results to return from the query"""

-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input):
+    def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatNormal,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatNormal,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )

     def generate_input_values(self):
-        input_values = {
-            "input": self.current_user_input
-        }
+        input_values = {"input": self.current_user_input}
         return input_values

     def do_with_prompt_response(self, prompt_response):
         return prompt_response

     @property
     def chat_type(self) -> str:
         return ChatScene.ChatNormal.value

View File

@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR

 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")


 class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

View File

@@ -29,5 +29,3 @@ prompt = PromptTemplate(
 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-

View File

@@ -85,9 +85,7 @@ add_knowledge_base_dialogue = get_lang_text(
     "knowledge_qa_type_add_knowledge_base_dialogue"
 )

-url_knowledge_dialogue = get_lang_text(
-    "knowledge_qa_type_url_knowledge_dialogue"
-)
+url_knowledge_dialogue = get_lang_text("knowledge_qa_type_url_knowledge_dialogue")

 knowledge_qa_type_list = [
     llm_native_dialogue,
@@ -205,9 +203,9 @@ def post_process_code(code):

 def get_chat_mode(selected, param=None) -> ChatScene:
-    if chat_mode_title['chat_use_plugin'] == selected:
+    if chat_mode_title["chat_use_plugin"] == selected:
         return ChatScene.ChatExecution
-    elif chat_mode_title['knowledge_qa'] == selected:
+    elif chat_mode_title["knowledge_qa"] == selected:
         mode = param
         if mode == conversation_types["default_knownledge"]:
             return ChatScene.ChatKnowledge
@@ -232,14 +230,23 @@ def chatbot_callback(state, message):

 def http_bot(
-    state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input,
-    knowledge_name
+    state,
+    selected,
+    temperature,
+    max_new_tokens,
+    plugin_selector,
+    mode,
+    sql_mode,
+    db_selector,
+    url_input,
+    knowledge_name,
 ):
     logger.info(
-        f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
-    if chat_mode_title['knowledge_qa'] == selected:
+        f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}"
+    )
+    if chat_mode_title["knowledge_qa"] == selected:
         scene: ChatScene = get_chat_mode(selected, mode)
-    elif chat_mode_title['chat_use_plugin'] == selected:
+    elif chat_mode_title["chat_use_plugin"] == selected:
         scene: ChatScene = get_chat_mode(selected)
     else:
         scene: ChatScene = get_chat_mode(selected, sql_mode)
@@ -251,7 +258,7 @@ def http_bot(
             "max_new_tokens": max_new_tokens,
             "chat_session_id": state.conv_id,
             "db_name": db_selector,
-            "user_input": state.last_user_input
+            "user_input": state.last_user_input,
         }
     elif ChatScene.ChatWithDbQA == scene:
         chat_param = {
@@ -289,7 +296,7 @@ def http_bot(
             "max_new_tokens": max_new_tokens,
             "chat_session_id": state.conv_id,
             "user_input": state.last_user_input,
-            "knowledge_name": knowledge_name
+            "knowledge_name": knowledge_name,
         }
     elif ChatScene.ChatUrlKnowledge == scene:
         chat_param = {
@@ -297,7 +304,7 @@ def http_bot(
             "max_new_tokens": max_new_tokens,
             "chat_session_id": state.conv_id,
             "user_input": state.last_user_input,
-            "url": url_input
+            "url": url_input,
         }
     else:
         state.messages[-1][-1] = f"ERROR: Can't support scene!{scene}"
@@ -314,7 +321,11 @@ def http_bot(
         response = chat.stream_call()
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
-                state.messages[-1][-1] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk,chat.skip_echo_len)
+                state.messages[-1][
+                    -1
+                ] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
+                    chunk, chat.skip_echo_len
+                )
                 yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
     except Exception as e:
         print(traceback.format_exc())
@@ -323,8 +334,8 @@ def http_bot(

 block_css = (
     code_highlight_css
     + """
 pre {
     white-space: pre-wrap;       /* Since CSS 2.1 */
     white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
@@ -361,7 +372,7 @@ def build_single_model_ui():
     gr.Markdown(notice_markdown, elem_id="notice_markdown")

     with gr.Accordion(
         get_lang_text("model_control_param"), open=False, visible=False
     ) as parameter_row:
         temperature = gr.Slider(
             minimum=0.0,
@@ -411,7 +422,7 @@ def build_single_model_ui():
                 get_lang_text("sql_generate_mode_none"),
             ],
             show_label=False,
-            value=get_lang_text("sql_generate_mode_none")
+            value=get_lang_text("sql_generate_mode_none"),
         )
         sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
         sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
@@ -428,15 +439,19 @@ def build_single_model_ui():
             value="",
             interactive=True,
             show_label=True,
-            type="value"
+            type="value",
         ).style(container=False)

-        def plugin_change(evt: gr.SelectData):  # SelectData is a subclass of EventData
+        def plugin_change(
+            evt: gr.SelectData,
+        ):  # SelectData is a subclass of EventData
             print(f"You selected {evt.value} at {evt.index} from {evt.target}")
             print(f"user plugin:{plugins_select_info().get(evt.value)}")
             return plugins_select_info().get(evt.value)

-        plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
+        plugin_selected = gr.Textbox(
+            show_label=False, visible=False, placeholder="Selected"
+        )

         plugin_selector.select(plugin_change, None, plugin_selected)

     tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
@@ -456,7 +471,12 @@ def build_single_model_ui():
                 )
                 mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)

-                url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False)
+                url_input = gr.Textbox(
+                    label=get_lang_text("url_input_label"),
+                    lines=1,
+                    interactive=True,
+                    visible=False,
+                )

                 def show_url_input(evt: gr.SelectData):
                     if evt.value == url_knowledge_dialogue:
@@ -559,10 +579,10 @@ def build_single_model_ui():

 def build_webdemo():
     with gr.Blocks(
         title=get_lang_text("database_smart_assistant"),
         # theme=gr.themes.Base(),
         theme=gr.themes.Default(),
         css=block_css,
     ) as demo:
         url_params = gr.JSON(visible=False)
         (

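http_bot above maps the selected UI tab and sub-mode to a ChatScene, then builds a per-scene chat_param dict for ChatFactory. A condensed sketch of that dispatch shape (scene keys and the stubbed state dict are ours; only the parameter names come from the diff):

    def build_chat_param(scene, state, temperature, max_new_tokens, **extras):
        """Assemble keyword arguments for the chat factory (sketch)."""
        param = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "chat_session_id": state["conv_id"],
            "user_input": state["last_user_input"],
        }
        # Scene-specific extras, mirroring the branches in http_bot.
        if scene == "chat_with_db_execute":
            param["db_name"] = extras["db_selector"]
        elif scene == "chat_url_knowledge":
            param["url"] = extras["url_input"]
        return param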
View File

@@ -18,7 +18,14 @@ CFG = Config()


 class KnowledgeEmbedding:
-    def __init__(self, file_path, model_name, vector_store_config, local_persist=True, file_type="default"):
+    def __init__(
+        self,
+        file_path,
+        model_name,
+        vector_store_config,
+        local_persist=True,
+        file_type="default",
+    ):
         """Initialize with Loader url, model_name, vector_store_config"""
         self.file_path = file_path
         self.model_name = model_name
@@ -63,7 +70,6 @@ class KnowledgeEmbedding:
                 vector_store_config=self.vector_store_config,
             )
         elif self.file_type == "default":
-
             embedding = MarkdownEmbedding(
                 file_path=self.file_path,
@@ -71,7 +77,6 @@ class KnowledgeEmbedding:
                 vector_store_config=self.vector_store_config,
             )

         return embedding

     def similar_search(self, text, topk):

View File

@@ -13,6 +13,7 @@ from pilot.summary.mysql_db_summary import MysqlSummary
 from pilot.scene.chat_factory import ChatFactory

 CFG = Config()
+chat_factory = ChatFactory()


 class DBSummaryClient:
@@ -88,13 +89,18 @@ class DBSummaryClient:
         )
         if CFG.SUMMARY_CONFIG == "FAST":
             table_docs = knowledge_embedding_client.similar_search(query, topk)
-            related_tables = [json.loads(table_doc.page_content)["table_name"] for table_doc in table_docs]
+            related_tables = [
+                json.loads(table_doc.page_content)["table_name"]
+                for table_doc in table_docs
+            ]
         else:
             table_docs = knowledge_embedding_client.similar_search(query, 1)
             # prompt = KnownLedgeBaseQA.build_db_summary_prompt(
             #     query, table_docs[0].page_content
             # )
-            related_tables = _get_llm_response(query, dbname, table_docs[0].page_content)
+            related_tables = _get_llm_response(
+                query, dbname, table_docs[0].page_content
+            )
         related_table_summaries = []
         for table in related_tables:
             vector_store_config = {
@@ -118,35 +124,14 @@ def _get_llm_response(query, db_input, dbsummary):
         "max_new_tokens": 512,
         "chat_session_id": uuid.uuid1(),
         "user_input": query,
-        "db_input": db_input,
+        "db_select": db_input,
         "db_summary": dbsummary,
     }
-    chat_factory = ChatFactory()
-    chat: BaseChat = chat_factory.get_implementation(ChatScene.InnerChatDBSummary.value(), **chat_param)
-
-    return chat.call()
-    # payload = {
-    #     "model": CFG.LLM_MODEL,
-    #     "prompt": prompt,
-    #     "temperature": float(0.7),
-    #     "max_new_tokens": int(512),
-    #     "stop": state.sep
-    #     if state.sep_style == SeparatorStyle.SINGLE
-    #     else state.sep2,
-    # }
-    # headers = {"User-Agent": "dbgpt Client"}
-    # response = requests.post(
-    #     urljoin(CFG.MODEL_SERVER, "generate"),
-    #     headers=headers,
-    #     json=payload,
-    #     timeout=120,
-    # )
-    #
-    # print(related_tables)
-    # return related_tables
-    # except NotCommands as e:
-    #     print("llm response error:" + e.message)
+    chat: BaseChat = chat_factory.get_implementation(
+        ChatScene.InnerChatDBSummary.value, **chat_param
+    )
+    res = chat.nostream_call()
+    return json.loads(res)["table"]


 # if __name__ == "__main__":

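The FAST branch above retrieves per-table summary documents from the vector store and reads the table name out of each document's JSON payload. A stub-backed sketch of that extraction (the Doc stand-in is ours):

    import json

    class Doc:
        def __init__(self, page_content):
            self.page_content = page_content

    table_docs = [
        Doc(json.dumps({"table_name": "orders", "table_description": "order header"})),
        Doc(json.dumps({"table_name": "products", "table_description": "catalog"})),
    ]
    related_tables = [json.loads(d.page_content)["table_name"] for d in table_docs]
    print(related_tables)  # ['orders', 'products']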
View File

@@ -37,20 +37,23 @@ class MysqlSummary(DBSummary):
                     table_name=table_comment[0], table_comment=table_comment[1]
                 )
             )
             vector_table = json.dumps(
                 {"table_name": table_comment[0], "table_description": table_comment[1]}
             )
             self.vector_tables_info.append(
                 vector_table.encode("utf-8").decode("unicode_escape")
             )
+        self.table_columns_info = []
         for table_name in tables:
             table_summary = MysqlTableSummary(self.db, name, table_name)
-            self.tables[table_name] = table_summary.get_summery()
+            # self.tables[table_name] = table_summary.get_summery()
+            self.tables[table_name] = table_summary.get_columns()
+            self.table_columns_info.append(table_summary.get_columns())
             # self.tables_info.append(table_summary.get_summery())

     def get_summery(self):
-        if CFG.SUMMARY_CONFIG == "VECTOR":
+        if CFG.SUMMARY_CONFIG == "FAST":
             return self.vector_tables_info
         else:
             return self.summery.format(
@@ -63,6 +66,9 @@ class MysqlSummary(DBSummary):
     def get_table_comments(self):
         return self.table_comments

+    def get_columns(self):
+        return self.table_columns_info
+

 class MysqlTableSummary(TableSummary):
     """Get mysql table summary template."""
@@ -78,10 +84,16 @@ class MysqlTableSummary(TableSummary):
         self.db = instance
         fields = self.db.get_fields(name)
         indexes = self.db.get_indexes(name)
+        field_names = []
         for field in fields:
             field_summary = MysqlFieldsSummary(field)
             self.fields.append(field_summary)
             self.fields_info.append(field_summary.get_summery())
+            field_names.append(field[0])
+
+        self.column_summery = """{name}({columns_info})""".format(
+            name=name, columns_info=",".join(field_names)
+        )

         for index in indexes:
             index_summary = MysqlIndexSummary(index)
@@ -96,6 +108,9 @@ class MysqlTableSummary(TableSummary):
             indexes=";".join(self.indexes_info),
         )

+    def get_columns(self):
+        return self.column_summery
+

 class MysqlFieldsSummary(FieldSummary):
     """Get mysql field summary template."""

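The new get_columns path condenses each table to a `name(col1,col2,...)` one-liner, which is what gets embedded for FAST retrieval. A sketch of that formatting (sample field tuples are ours; as in the diff, the column name is assumed to be the first element):

    fields = [("id", "int"), ("user_id", "int"), ("amount", "decimal")]
    field_names = [field[0] for field in fields]
    column_summary = "{name}({columns_info})".format(
        name="orders", columns_info=",".join(field_names)
    )
    print(column_summary)  # orders(id,user_id,amount)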
View File

@@ -1,4 +1,5 @@
 from pilot.vector_store.chroma_store import ChromaStore
+
 # from pilot.vector_store.milvus_store import MilvusStore

 connector = {"Chroma": ChromaStore, "Milvus": None}