mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-24 19:13:33 +00:00

commit 7e6cf9c9e0

Merge branch 'summary' into dev

# Conflicts:
#	pilot/model/llm_out/guanaco_llm.py
#	pilot/out_parser/base.py
#	pilot/scene/base_chat.py
@@ -55,8 +55,6 @@ def fix_and_parse_json(
         logger.error("参数解析错误", e)
-
-
 
 
 def correct_json(json_to_load: str) -> str:
     """
     Correct common JSON errors.
@@ -1,6 +1,7 @@
 from collections import OrderedDict
+from collections import deque
 
 
 class FixedSizeDict(OrderedDict):
     def __init__(self, max_size):
         super().__init__()

@@ -11,6 +12,7 @@ class FixedSizeDict(OrderedDict):
             self.popitem(last=False)
         super().__setitem__(key, value)
 
+
 class FixedSizeList:
     def __init__(self, max_size):
         self.max_size = max_size

@@ -29,4 +31,4 @@ class FixedSizeList:
         return len(self.list)
 
     def __str__(self):
-        return str(list(self.list))
\ No newline at end of file
+        return str(list(self.list))
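Note for orientation: the FixedSizeDict/FixedSizeList hunks above are formatting-only, but the class itself is a bounded mapping that evicts its oldest entry. A minimal sketch of the behavior (the eviction line matches the `popitem(last=False)` in the hunk; the rest is assumed from context):

    from collections import OrderedDict

    class FixedSizeDict(OrderedDict):
        """Dict that drops its oldest entry once max_size is reached."""

        def __init__(self, max_size):
            super().__init__()
            self.max_size = max_size

        def __setitem__(self, key, value):
            if len(self) >= self.max_size:
                self.popitem(last=False)  # evict the oldest insertion
            super().__setitem__(key, value)

    d = FixedSizeDict(2)
    d["a"], d["b"], d["c"] = 1, 2, 3
    print(dict(d))  # {'b': 2, 'c': 3} -- "a" was evicted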
@@ -1,6 +1,6 @@
 from __future__ import annotations
+import sqlparse
 import regex as re
-import regex as re
 import warnings
 from typing import Any, Iterable, List, Optional
 from pydantic import BaseModel, Field, root_validator, validator, Extra

@@ -283,16 +283,16 @@ class Database:
             return f"Error: {e}"
 
     def __write(self, session, write_sql):
-        print(f"Write[{write_sql}]")
-        db_cache = self.get_session_db(session)
-        result = session.execute(text(write_sql))
-        session.commit()
-        #TODO Subsequent optimization of dynamically specified database submission loss target problem
-        session.execute(text(f"use `{db_cache}`"))
-        print(f"SQL[{write_sql}], result:{result.rowcount}")
-        return result.rowcount
+        print(f"Write[{write_sql}]")
+        db_cache = self.get_session_db(session)
+        result = session.execute(text(write_sql))
+        session.commit()
+        # TODO Subsequent optimization of dynamically specified database submission loss target problem
+        session.execute(text(f"use `{db_cache}`"))
+        print(f"SQL[{write_sql}], result:{result.rowcount}")
+        return result.rowcount
 
-    def __query(self,session, query, fetch: str = "all"):
+    def __query(self, session, query, fetch: str = "all"):
         """
         only for query
         Args:
@@ -390,37 +390,44 @@ class Database:
         cmd_type = parts[0]
 
         # Dispatch on the command type
-        if cmd_type == 'insert':
-            match = re.match(r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower())
+        if cmd_type == "insert":
+            match = re.match(
+                r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
+            )
             if match:
                 table_name, columns, values = match.groups()
                 # Split the column list and value list into individual columns and values
-                columns = columns.split(',')
-                values = values.split(',')
+                columns = columns.split(",")
+                values = values.split(",")
                 # Build the WHERE clause
-                where_clause = " AND ".join([f"{col.strip()}={val.strip()}" for col, val in zip(columns, values)])
-                return f'SELECT * FROM {table_name} WHERE {where_clause}'
+                where_clause = " AND ".join(
+                    [
+                        f"{col.strip()}={val.strip()}"
+                        for col, val in zip(columns, values)
+                    ]
+                )
+                return f"SELECT * FROM {table_name} WHERE {where_clause}"
 
-        elif cmd_type == 'delete':
+        elif cmd_type == "delete":
             table_name = parts[2]  # delete from <table_name> ...
             # Return a SELECT statement that selects all data from the table
-            return f'SELECT * FROM {table_name}'
+            return f"SELECT * FROM {table_name}"
 
-        elif cmd_type == 'update':
+        elif cmd_type == "update":
             table_name = parts[1]
-            set_idx = parts.index('set')
-            where_idx = parts.index('where')
+            set_idx = parts.index("set")
+            where_idx = parts.index("where")
             # Extract the field name from the `set` clause
-            set_clause = parts[set_idx + 1: where_idx][0].split('=')[0].strip()
+            set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip()
             # Extract the condition after `where`
-            where_clause = ' '.join(parts[where_idx + 1:])
+            where_clause = " ".join(parts[where_idx + 1 :])
             # Return a SELECT statement that selects the updated data
-            return f'SELECT {set_clause} FROM {table_name} WHERE {where_clause}'
+            return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
         else:
             raise ValueError(f"Unsupported SQL command type: {cmd_type}")
 
     def __sql_parse(self, sql):
-        sql = sql.strip()
+        sql = sql.strip()
         parsed = sqlparse.parse(sql)[0]
         sql_type = parsed.get_type()
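The hunk above reformats the helper that rewrites a DML statement into an equivalent SELECT so the affected rows can be shown back to the user. A condensed sketch of the insert branch (hypothetical input; regex as in the diff):

    import regex as re

    def write_to_select(write_sql: str) -> str:
        # Rebuild a SELECT from the column/value pairs of an INSERT.
        match = re.match(
            r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
        )
        table_name, columns, values = match.groups()
        where_clause = " AND ".join(
            f"{c.strip()}={v.strip()}"
            for c, v in zip(columns.split(","), values.split(","))
        )
        return f"SELECT * FROM {table_name} WHERE {where_clause}"

    print(write_to_select("INSERT INTO users (id, name) VALUES (1, 'bob')"))
    # SELECT * FROM users WHERE id=1 AND name='bob'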
@@ -429,8 +436,6 @@ class Database:
         print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}")
         return parsed, ttype, sql_type
 
-
-
     def get_indexes(self, table_name):
         """Get table indexes about specified table."""
         session = self._db_sessions()

@@ -103,8 +103,12 @@ class Config(metaclass=Singleton):
         else:
             self.plugins_denylist = []
         ### Native SQL Execution Capability Control Configuration
-        self.NATIVE_SQL_CAN_RUN_DDL = os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") =="True"
-        self.NATIVE_SQL_CAN_RUN_WRITE = os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") =="True"
+        self.NATIVE_SQL_CAN_RUN_DDL = (
+            os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True"
+        )
+        self.NATIVE_SQL_CAN_RUN_WRITE = (
+            os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
+        )
 
         ### Local database connection configuration
         self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
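A note on the pattern being wrapped here: os.getenv returns a string (or the default), so the == "True" comparison is what converts the setting into a bool. A quick sketch:

    import os

    os.environ["NATIVE_SQL_CAN_RUN_DDL"] = "False"
    # Any value other than the exact string "True" disables the capability.
    can_run_ddl = os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True"
    print(can_run_ddl)  # False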
@@ -11,11 +11,9 @@ class BaseConnect(BaseModel, ABC):
     type
     driver: str
 
-
     def get_session(self, db_name: str):
         pass
 
-
     def get_table_names(self) -> Iterable[str]:
         pass
 

@@ -32,4 +30,4 @@ class BaseConnect(BaseModel, ABC):
         pass
 
     def run(self, session, command: str, fetch: str = "all") -> List:
-        pass
\ No newline at end of file
+        pass

@@ -11,8 +11,7 @@ class MySQLConnect(RDBMSDatabase):
     Usage:
     """
 
-
-    type:str = "MySQL"
+    type: str = "MySQL"
     connect_url = "mysql+pymysql://"
 
     default_db = ["information_schema", "performance_schema", "sys", "mysql"]

@@ -106,7 +106,7 @@ class Conversation:
 
 
 conv_default = Conversation(
-    system = None,
+    system=None,
     roles=("human", "ai"),
     messages=[],
     offset=0,

@@ -298,7 +298,6 @@ chat_mode_title = {
     "sql_generate_diagnostics": get_lang_text("sql_analysis_and_diagnosis"),
     "chat_use_plugin": get_lang_text("chat_use_plugin"),
     "knowledge_qa": get_lang_text("knowledge_qa"),
-
 }
 
 conversation_sql_mode = {

@@ -29,7 +29,6 @@ lang_dicts = {
         "url_input_label": "输入网页地址",
         "add_as_new_klg": "添加为新知识库",
         "add_file_to_klg": "向知识库中添加文件",
-
         "upload_file": "上传文件",
         "add_file": "添加文件",
         "upload_and_load_to_klg": "上传并加载到知识库",

@@ -9,7 +9,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 
 @torch.inference_mode()
 def chatglm_generate_stream(
-    model, tokenizer, params, device, context_len=2048, stream_interval=2
+    model, tokenizer, params, device, context_len=2048, stream_interval=2
 ):
     """Generate text using chatglm model's chat api"""
     prompt = params["prompt"]

@@ -57,7 +57,7 @@ def chatglm_generate_stream(
     # i = 0
 
     for i, (response, new_hist) in enumerate(
-        model.stream_chat(tokenizer, query, hist, **generate_kwargs)
+        model.stream_chat(tokenizer, query, hist, **generate_kwargs)
     ):
         if echo:
             output = query + " " + response

@@ -83,9 +83,7 @@ class BaseOutputParser(ABC):
                 output = self.__post_process_code(output)
                 yield output
             else:
-                output = (
-                    data["text"] + f" (error_code: {data['error_code']})"
-                )
+                output = data["text"] + f" (error_code: {data['error_code']})"
                 yield output
 
     def parse_model_nostream_resp(self, response, sep: str):

@@ -133,10 +133,8 @@ class PluginPromptGenerator:
         else:
             return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
 
-
-    def generate_commands_string(self)->str:
-        return f"{self._generate_numbered_list(self.commands, item_type='command')}"
-
+    def generate_commands_string(self) -> str:
+        return f"{self._generate_numbered_list(self.commands, item_type='command')}"
 
     def generate_prompt_string(self) -> str:
         """

@@ -33,13 +33,13 @@ class PromptTemplate(BaseModel, ABC):
     """A list of the names of the variables the prompt template expects."""
     template_scene: Optional[str]
 
-    template_define: Optional[str]
+    template_define: Optional[str]
     """this template define"""
-    template: Optional[str]
+    template: Optional[str]
     """The prompt template."""
     template_format: str = "f-string"
     """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
-    response_format: Optional[str]
+    response_format: Optional[str]
     """default use stream out"""
     stream_out: bool = True
     """"""
@@ -62,7 +62,9 @@ class PromptTemplate(BaseModel, ABC):
         if self.template:
             if self.response_format:
                 kwargs["response"] = json.dumps(self.response_format, indent=4)
-            return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
+            return DEFAULT_FORMATTER_MAPPING[self.template_format](
+                self.template, **kwargs
+            )
 
     def add_goals(self, goal: str) -> None:
         self.goals.append(goal)
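For context on the wrapped return: the template is rendered by looking up a formatter for template_format and applying it with the collected kwargs, including the JSON response format injected just above. A sketch under the assumption that the "f-string" entry is a str.format-style formatter (the real mapping lives elsewhere in the repo):

    import json

    # Assumed shape of DEFAULT_FORMATTER_MAPPING, for illustration only.
    DEFAULT_FORMATTER_MAPPING = {"f-string": lambda tpl, **kw: tpl.format(**kw)}

    template = "Question: {input}\nRespond in this format: {response}"
    kwargs = {"input": "list all users", "response": json.dumps({"sql": ""}, indent=4)}
    print(DEFAULT_FORMATTER_MAPPING["f-string"](template, **kwargs))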
@@ -129,7 +129,7 @@ class BaseChat(ABC):
     def stream_call(self):
         payload = self.__call_base()
 
-        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " "))
+        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:

@@ -141,7 +141,7 @@ class BaseChat(ABC):
             stream=True,
             timeout=120,
         )
-        return response;
+        return response
 
         # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len)
 

@@ -21,15 +21,21 @@ class ChatWithDbAutoExecute(BaseChat):
 
     """Number of results to return from the query"""
 
-    def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input):
+    def __init__(
+        self, temperature, max_new_tokens, chat_session_id, db_name, user_input
+    ):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatWithDbExecute,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatWithDbExecute,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         if not db_name:
-            raise ValueError(f"{ChatScene.ChatWithDbExecute.value} mode should chose db!")
+            raise ValueError(
+                f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
+            )
         self.db_name = db_name
         self.database = CFG.local_db
         # Prepare DB info (get the connection to the specified database)
@@ -40,9 +46,7 @@ class ChatWithDbAutoExecute(BaseChat):
         try:
             from pilot.summary.db_summary_client import DBSummaryClient
         except ImportError:
-            raise ValueError(
-                "Could not import DBSummaryClient. "
-            )
+            raise ValueError("Could not import DBSummaryClient. ")
         input_values = {
             "input": self.current_user_input,
             "top_k": str(self.top_k),

@@ -20,9 +20,8 @@ class DbChatOutputParser(BaseOutputParser):
     def __init__(self, sep: str, is_stream_out: bool):
         super().__init__(sep=sep, is_stream_out=is_stream_out)
 
-
     def parse_prompt_response(self, model_out_text):
-        clean_str = super().parse_prompt_response(model_out_text);
+        clean_str = super().parse_prompt_response(model_out_text)
         print("clean prompt response:", clean_str)
         response = json.loads(clean_str)
         sql, thoughts = response["sql"], response["thoughts"]

@@ -62,4 +62,3 @@ prompt = PromptTemplate(
     ),
 )
 CFG.prompt_templates.update({prompt.template_scene: prompt})
-

@@ -19,13 +19,17 @@ class ChatWithDbQA(BaseChat):
 
     """Number of results to return from the query"""
 
-    def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input):
+    def __init__(
+        self, temperature, max_new_tokens, chat_session_id, db_name, user_input
+    ):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatWithDbQA,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatWithDbQA,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         self.db_name = db_name
         if db_name:
             self.database = CFG.local_db

@@ -34,17 +38,16 @@ class ChatWithDbQA(BaseChat):
         self.top_k: int = 5
 
     def generate_input_values(self):
-
         table_info = ""
         dialect = "mysql"
         try:
             from pilot.summary.db_summary_client import DBSummaryClient
         except ImportError:
-            raise ValueError(
-                "Could not import DBSummaryClient. "
-            )
+            raise ValueError("Could not import DBSummaryClient. ")
         if self.db_name:
-            table_info = DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
+            table_info = DBSummaryClient.get_similar_tables(
+                dbname=self.db_name, query=self.current_user_input, topk=self.top_k
+            )
             # table_info = self.database.table_simple_info(self.db_connect)
             dialect = self.database.dialect
 

@@ -52,7 +55,7 @@ class ChatWithDbQA(BaseChat):
             "input": self.current_user_input,
             "top_k": str(self.top_k),
             "dialect": dialect,
-            "table_info": table_info
+            "table_info": table_info,
         }
         return input_values
 

@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR
 
 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 
-class NormalChatOutputParser(BaseOutputParser):
 
+class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text
 

@@ -27,7 +27,6 @@ Pay attention to use only the column names that you can see in the schema descri
 """
 
 
-
 PROMPT_SEP = SeparatorStyle.SINGLE.value
 
 PROMPT_NEED_NEED_STREAM_OUT = True

@@ -37,7 +36,7 @@ prompt = PromptTemplate(
     input_variables=["input", "table_info", "dialect", "top_k"],
     response_format=None,
     template_define=PROMPT_SCENE_DEFINE,
-    template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX ,
+    template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX,
     stream_out=PROMPT_NEED_NEED_STREAM_OUT,
     output_parser=NormalChatOutputParser(
         sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT

@@ -45,4 +44,3 @@ prompt = PromptTemplate(
 )
 
 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
@@ -14,56 +14,72 @@ from pilot.scene.chat_execution.prompt import prompt
 
 CFG = Config()
 
 
 class ChatWithPlugin(BaseChat):
     chat_scene: str = ChatScene.ChatExecution.value
-    plugins_prompt_generator:PluginPromptGenerator
+    plugins_prompt_generator: PluginPromptGenerator
     select_plugin: str = None
 
-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, plugin_selector:str=None):
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatExecution,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+    def __init__(
+        self,
+        temperature,
+        max_new_tokens,
+        chat_session_id,
+        user_input,
+        plugin_selector: str = None,
+    ):
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatExecution,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         self.plugins_prompt_generator = PluginPromptGenerator()
         self.plugins_prompt_generator.command_registry = CFG.command_registry
         # Load the commands available in the plugins
         self.select_plugin = plugin_selector
         if self.select_plugin:
             for plugin in CFG.plugins:
-                if plugin._name == plugin_selector :
+                if plugin._name == plugin_selector:
                     if not plugin.can_handle_post_prompt():
                         continue
-                    self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator)
+                    self.plugins_prompt_generator = plugin.post_prompt(
+                        self.plugins_prompt_generator
+                    )
 
         else:
             for plugin in CFG.plugins:
                 if not plugin.can_handle_post_prompt():
                     continue
-                self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator)
-
-
-
+                self.plugins_prompt_generator = plugin.post_prompt(
+                    self.plugins_prompt_generator
+                )
 
     def generate_input_values(self):
         input_values = {
             "input": self.current_user_input,
-            "constraints": self.__list_to_prompt_str(list(self.plugins_prompt_generator.constraints)),
-            "commands_infos": self.plugins_prompt_generator.generate_commands_string()
+            "constraints": self.__list_to_prompt_str(
+                list(self.plugins_prompt_generator.constraints)
+            ),
+            "commands_infos": self.plugins_prompt_generator.generate_commands_string(),
         }
         return input_values
 
     def do_with_prompt_response(self, prompt_response):
         ## plugin command run
-        return execute_command(str(prompt_response.command.get('name')), prompt_response.command.get('args',{}), self.plugins_prompt_generator)
+        return execute_command(
+            str(prompt_response.command.get("name")),
+            prompt_response.command.get("args", {}),
+            self.plugins_prompt_generator,
+        )
 
     def chat_show(self):
         super().chat_show()
 
-
     def __list_to_prompt_str(self, list: List) -> str:
         if list:
-            separator = '\n'
+            separator = "\n"
             return separator.join(list)
         else:
             return ""
@@ -10,14 +10,13 @@ from pilot.configs.model_config import LOGDIR
 
 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 
+
 class PluginAction(NamedTuple):
     command: Dict
     thoughts: Dict
 
 
-
 class PluginChatOutputParser(BaseOutputParser):
-
     def parse_prompt_response(self, model_out_text) -> T:
         response = json.loads(super().parse_prompt_response(model_out_text))
         command, thoughts = response["command"], response["thoughts"]

@@ -25,7 +24,7 @@ class PluginChatOutputParser(BaseOutputParser):
 
     def parse_view_response(self, speak, data) -> str:
         ### tool out data to table view
-        print(f"parse_view_response:{speak},{str(data)}" )
+        print(f"parse_view_response:{speak},{str(data)}")
         view_text = f"##### {speak}" + "\n" + str(data)
         return view_text
 

@@ -53,7 +53,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False
 
 prompt = PromptTemplate(
     template_scene=ChatScene.ChatExecution.value,
-    input_variables=["input", "constraints", "commands_infos", "response"],
+    input_variables=["input", "constraints", "commands_infos", "response"],
     response_format=json.dumps(RESPONSE_FORMAT, indent=4),
     template_define=PROMPT_SCENE_DEFINE,
     template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,

@@ -63,4 +63,4 @@ prompt = PromptTemplate(
     ),
 )
 
-CFG.prompt_templates.update({prompt.template_scene: prompt})
\ No newline at end of file
+CFG.prompt_templates.update({prompt.template_scene: prompt})

@@ -11,6 +11,7 @@ from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
 from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
 from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
 
+
 class ChatFactory(metaclass=Singleton):
     @staticmethod
     def get_implementation(chat_mode, **kwargs):
@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database

@@ -24,18 +23,22 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 CFG = Config()
 
 
-class ChatNewKnowledge (BaseChat):
+class ChatNewKnowledge(BaseChat):
     chat_scene: str = ChatScene.ChatNewKnowledge.value
 
     """Number of results to return from the query"""
 
-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, knowledge_name):
+    def __init__(
+        self, temperature, max_new_tokens, chat_session_id, user_input, knowledge_name
+    ):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatNewKnowledge,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatNewKnowledge,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         self.knowledge_name = knowledge_name
         vector_store_config = {
             "vector_store_name": knowledge_name,

@@ -49,21 +52,17 @@ class ChatNewKnowledge(BaseChat):
             vector_store_config=vector_store_config,
         )
 
-
     def generate_input_values(self):
-        docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
+        docs = self.knowledge_embedding_client.similar_search(
+            self.current_user_input, VECTOR_SEARCH_TOP_K
+        )
         docs = docs[:2000]
-        input_values = {
-            "context": docs,
-            "question": self.current_user_input
-        }
+        input_values = {"context": docs, "question": self.current_user_input}
         return input_values
 
     def do_with_prompt_response(self, prompt_response):
         return prompt_response
 
-
-
     @property
     def chat_type(self) -> str:
         return ChatScene.ChatNewKnowledge.value

@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR
 
 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 
-class NormalChatOutputParser(BaseOutputParser):
 
+class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text
 

@@ -23,7 +23,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用
 """
 
 
-
 PROMPT_SEP = SeparatorStyle.SINGLE.value
 
 PROMPT_NEED_NEED_STREAM_OUT = True

@@ -42,5 +41,3 @@ prompt = PromptTemplate(
 
 
 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-

@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database

@@ -24,43 +23,42 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 CFG = Config()
 
 
-class ChatDefaultKnowledge (BaseChat):
+class ChatDefaultKnowledge(BaseChat):
     chat_scene: str = ChatScene.ChatKnowledge.value
 
     """Number of results to return from the query"""
 
-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input):
+    def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatKnowledge,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatKnowledge,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         vector_store_config = {
             "vector_store_name": "default",
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
-            model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
-            vector_store_config=vector_store_config,
-        )
+            file_path="",
+            model_name=LLM_MODEL_CONFIG["text2vec"],
+            local_persist=False,
+            vector_store_config=vector_store_config,
+        )
 
     def generate_input_values(self):
-        docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
+        docs = self.knowledge_embedding_client.similar_search(
+            self.current_user_input, VECTOR_SEARCH_TOP_K
+        )
         docs = docs[:2000]
-        input_values = {
-            "context": docs,
-            "question": self.current_user_input
-        }
+        input_values = {"context": docs, "question": self.current_user_input}
         return input_values
 
     def do_with_prompt_response(self, prompt_response):
         return prompt_response
 
-
-
     @property
     def chat_type(self) -> str:
         return ChatScene.ChatKnowledge.value

@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR
 
 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 
-class NormalChatOutputParser(BaseOutputParser):
 
+class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text
 

@@ -20,7 +20,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用
 """
 
 
-
 PROMPT_SEP = SeparatorStyle.SINGLE.value
 
 PROMPT_NEED_NEED_STREAM_OUT = True

@@ -39,5 +38,3 @@ prompt = PromptTemplate(
 
 
 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-

@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
 from pilot.configs.config import Config
@@ -8,34 +7,42 @@ from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt
 CFG = Config()
 
 
-class InnerChatDBSummary (BaseChat):
+class InnerChatDBSummary(BaseChat):
     chat_scene: str = ChatScene.InnerChatDBSummary.value
 
     """Number of results to return from the query"""
 
-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, db_select, db_summary):
+    def __init__(
+        self,
+        temperature,
+        max_new_tokens,
+        chat_session_id,
+        user_input,
+        db_select,
+        db_summary,
+    ):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.InnerChatDBSummary,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
-        self.db_name = db_select
-        self.db_summary = db_summary
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.InnerChatDBSummary,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
+
+        self.db_input = db_select
+        self.db_summary = db_summary
 
     def generate_input_values(self):
         input_values = {
-            "db_input": self.db_name,
-            "db_profile_summary": self.db_summary
+            "db_input": self.db_input,
+            "db_profile_summary": self.db_summary,
         }
         return input_values
 
     def do_with_prompt_response(self, prompt_response):
         return prompt_response
 
-
-
     @property
     def chat_type(self) -> str:
         return ChatScene.InnerChatDBSummary.value

@@ -10,13 +10,15 @@ from pilot.configs.model_config import LOGDIR
 
 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 
-class NormalChatOutputParser(BaseOutputParser):
 
-    def parse_prompt_response(self, model_out_text) -> T:
-        return model_out_text
+class NormalChatOutputParser(BaseOutputParser):
+    def parse_prompt_response(self, model_out_text):
+        clean_str = super().parse_prompt_response(model_out_text)
+        print("clean prompt response:", clean_str)
+        return clean_str
 
     def parse_view_response(self, ai_text, data) -> str:
-        return ai_text["table"]
+        return ai_text
 
     def get_format_instructions(self) -> str:
         pass

@@ -7,33 +7,30 @@ from pilot.configs.config import Config
 from pilot.scene.base import ChatScene
 from pilot.common.schema import SeparatorStyle
 
-from pilot.scene.chat_knowledge.inner_db_summary.out_parser import NormalChatOutputParser
+from pilot.scene.chat_knowledge.inner_db_summary.out_parser import (
+    NormalChatOutputParser,
+)
 
 
 CFG = Config()
 
-PROMPT_SCENE_DEFINE =""""""
+PROMPT_SCENE_DEFINE = """"""
 
 _DEFAULT_TEMPLATE = """
 Based on the following known database information?, answer which tables are involved in the user input.
 Known database information:{db_profile_summary}
 Input:{db_input}
 You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads
 The response format must be JSON, and the key of JSON must be "table".
-
-
 """
 PROMPT_RESPONSE = """You must respond in JSON format as following format:
 {response}
 
 Ensure the response is correct json and can be parsed by Python json.loads
 The response format must be JSON, and the key of JSON must be "table".
 """
 
-
-RESPONSE_FORMAT = {
-    "table": ["orders", "products"]
-}
-
+RESPONSE_FORMAT = {"table": ["orders", "products"]}
 
 PROMPT_SEP = SeparatorStyle.SINGLE.value
@@ -54,5 +51,3 @@ prompt = PromptTemplate(
 
 
 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-

@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database

@@ -24,18 +23,20 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 CFG = Config()
 
 
-class ChatUrlKnowledge (BaseChat):
+class ChatUrlKnowledge(BaseChat):
     chat_scene: str = ChatScene.ChatUrlKnowledge.value
 
     """Number of results to return from the query"""
 
-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, url):
+    def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, url):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatUrlKnowledge,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatUrlKnowledge,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
         self.url = url
         vector_store_config = {
             "vector_store_name": url,

@@ -54,19 +55,16 @@ class ChatUrlKnowledge(BaseChat):
         self.knowledge_embedding_client.knowledge_embedding()
 
     def generate_input_values(self):
-        docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
+        docs = self.knowledge_embedding_client.similar_search(
+            self.current_user_input, VECTOR_SEARCH_TOP_K
+        )
         docs = docs[:2000]
-        input_values = {
-            "context": docs,
-            "question": self.current_user_input
-        }
+        input_values = {"context": docs, "question": self.current_user_input}
         return input_values
 
     def do_with_prompt_response(self, prompt_response):
         return prompt_response
 
-
-
     @property
     def chat_type(self) -> str:
         return ChatScene.ChatUrlKnowledge.value

@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR
 
 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 
-class NormalChatOutputParser(BaseOutputParser):
 
+class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text
 

@@ -20,7 +20,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用
 """
 
 
-
 PROMPT_SEP = SeparatorStyle.SINGLE.value
 
 PROMPT_NEED_NEED_STREAM_OUT = True

@@ -39,5 +38,3 @@ prompt = PromptTemplate(
 
 
 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-

@@ -1,4 +1,3 @@
-
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database

@@ -19,25 +18,23 @@ class ChatNormal(BaseChat):
 
     """Number of results to return from the query"""
 
-    def __init__(self,temperature, max_new_tokens, chat_session_id, user_input):
+    def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
         """ """
-        super().__init__(temperature=temperature,
-                         max_new_tokens=max_new_tokens,
-                         chat_mode=ChatScene.ChatNormal,
-                         chat_session_id=chat_session_id,
-                         current_user_input=user_input)
+        super().__init__(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            chat_mode=ChatScene.ChatNormal,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
 
     def generate_input_values(self):
-        input_values = {
-            "input": self.current_user_input
-        }
+        input_values = {"input": self.current_user_input}
         return input_values
 
     def do_with_prompt_response(self, prompt_response):
         return prompt_response
 
-
-
     @property
     def chat_type(self) -> str:
         return ChatScene.ChatNormal.value

@@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR
 
 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 
-class NormalChatOutputParser(BaseOutputParser):
 
+class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text
 

@@ -29,5 +29,3 @@ prompt = PromptTemplate(
 
 
 CFG.prompt_templates.update({prompt.template_scene: prompt})
-
-
@@ -85,9 +85,7 @@ add_knowledge_base_dialogue = get_lang_text(
     "knowledge_qa_type_add_knowledge_base_dialogue"
 )
 
-url_knowledge_dialogue = get_lang_text(
-    "knowledge_qa_type_url_knowledge_dialogue"
-)
+url_knowledge_dialogue = get_lang_text("knowledge_qa_type_url_knowledge_dialogue")
 
 knowledge_qa_type_list = [
     llm_native_dialogue,

@@ -205,9 +203,9 @@ def post_process_code(code):
 
 
 def get_chat_mode(selected, param=None) -> ChatScene:
-    if chat_mode_title['chat_use_plugin'] == selected:
+    if chat_mode_title["chat_use_plugin"] == selected:
         return ChatScene.ChatExecution
-    elif chat_mode_title['knowledge_qa'] == selected:
+    elif chat_mode_title["knowledge_qa"] == selected:
         mode = param
         if mode == conversation_types["default_knownledge"]:
             return ChatScene.ChatKnowledge

@@ -232,14 +230,23 @@ def chatbot_callback(state, message):
 
 
 def http_bot(
-    state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input,
-    knowledge_name
+    state,
+    selected,
+    temperature,
+    max_new_tokens,
+    plugin_selector,
+    mode,
+    sql_mode,
+    db_selector,
+    url_input,
+    knowledge_name,
 ):
     logger.info(
-        f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
-    if chat_mode_title['knowledge_qa'] == selected:
+        f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}"
+    )
+    if chat_mode_title["knowledge_qa"] == selected:
         scene: ChatScene = get_chat_mode(selected, mode)
-    elif chat_mode_title['chat_use_plugin'] == selected:
+    elif chat_mode_title["chat_use_plugin"] == selected:
         scene: ChatScene = get_chat_mode(selected)
     else:
         scene: ChatScene = get_chat_mode(selected, sql_mode)

@@ -251,7 +258,7 @@ def http_bot(
             "max_new_tokens": max_new_tokens,
             "chat_session_id": state.conv_id,
             "db_name": db_selector,
-            "user_input": state.last_user_input
+            "user_input": state.last_user_input,
         }
     elif ChatScene.ChatWithDbQA == scene:
         chat_param = {

@@ -289,7 +296,7 @@ def http_bot(
             "max_new_tokens": max_new_tokens,
             "chat_session_id": state.conv_id,
             "user_input": state.last_user_input,
-            "knowledge_name": knowledge_name
+            "knowledge_name": knowledge_name,
         }
     elif ChatScene.ChatUrlKnowledge == scene:
         chat_param = {

@@ -297,7 +304,7 @@ def http_bot(
             "max_new_tokens": max_new_tokens,
             "chat_session_id": state.conv_id,
             "user_input": state.last_user_input,
-            "url": url_input
+            "url": url_input,
         }
     else:
         state.messages[-1][-1] = f"ERROR: Can't support scene!{scene}"
@@ -314,7 +321,11 @@ def http_bot(
         response = chat.stream_call()
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
-                state.messages[-1][-1] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk,chat.skip_echo_len)
+                state.messages[-1][
+                    -1
+                ] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
+                    chunk, chat.skip_echo_len
+                )
                 yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
     except Exception as e:
         print(traceback.format_exc())
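For context on this loop: the model server streams NUL-delimited chunks, and each chunk is re-parsed to refresh the last chat message. A reduced sketch (endpoint and payload are hypothetical; the iter_lines usage matches the hunk):

    import requests

    response = requests.post(
        "http://127.0.0.1:8000/generate",  # hypothetical model server endpoint
        json={"prompt": "hello"},
        stream=True,
        timeout=120,
    )
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            print(chunk.decode())  # each NUL-delimited chunk is one partial completion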
@@ -323,8 +334,8 @@ def http_bot(
 
 
 block_css = (
-    code_highlight_css
-    + """
+    code_highlight_css
+    + """
 pre {
     white-space: pre-wrap;       /* Since CSS 2.1 */
     white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */

@@ -361,7 +372,7 @@ def build_single_model_ui():
     gr.Markdown(notice_markdown, elem_id="notice_markdown")
 
     with gr.Accordion(
-            get_lang_text("model_control_param"), open=False, visible=False
+        get_lang_text("model_control_param"), open=False, visible=False
     ) as parameter_row:
         temperature = gr.Slider(
             minimum=0.0,

@@ -411,7 +422,7 @@ def build_single_model_ui():
                 get_lang_text("sql_generate_mode_none"),
             ],
             show_label=False,
-            value=get_lang_text("sql_generate_mode_none")
+            value=get_lang_text("sql_generate_mode_none"),
         )
         sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
         sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)

@@ -428,15 +439,19 @@ def build_single_model_ui():
             value="",
             interactive=True,
             show_label=True,
-            type="value"
+            type="value",
         ).style(container=False)
 
-        def plugin_change(evt: gr.SelectData):  # SelectData is a subclass of EventData
+        def plugin_change(
+            evt: gr.SelectData,
+        ):  # SelectData is a subclass of EventData
             print(f"You selected {evt.value} at {evt.index} from {evt.target}")
             print(f"user plugin:{plugins_select_info().get(evt.value)}")
             return plugins_select_info().get(evt.value)
 
-        plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
+        plugin_selected = gr.Textbox(
+            show_label=False, visible=False, placeholder="Selected"
+        )
         plugin_selector.select(plugin_change, None, plugin_selected)
 
     tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")

@@ -456,7 +471,12 @@ def build_single_model_ui():
         )
         mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
 
-        url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False)
+        url_input = gr.Textbox(
+            label=get_lang_text("url_input_label"),
+            lines=1,
+            interactive=True,
+            visible=False,
+        )
 
         def show_url_input(evt: gr.SelectData):
             if evt.value == url_knowledge_dialogue:

@@ -559,10 +579,10 @@ def build_single_model_ui():
 
 def build_webdemo():
     with gr.Blocks(
-            title=get_lang_text("database_smart_assistant"),
-            # theme=gr.themes.Base(),
-            theme=gr.themes.Default(),
-            css=block_css,
+        title=get_lang_text("database_smart_assistant"),
+        # theme=gr.themes.Base(),
+        theme=gr.themes.Default(),
+        css=block_css,
     ) as demo:
         url_params = gr.JSON(visible=False)
         (

@@ -18,7 +18,14 @@ CFG = Config()
 
 
 class KnowledgeEmbedding:
-    def __init__(self, file_path, model_name, vector_store_config, local_persist=True, file_type="default"):
+    def __init__(
+        self,
+        file_path,
+        model_name,
+        vector_store_config,
+        local_persist=True,
+        file_type="default",
+    ):
         """Initialize with Loader url, model_name, vector_store_config"""
         self.file_path = file_path
         self.model_name = model_name

@@ -63,7 +70,6 @@ class KnowledgeEmbedding:
                 vector_store_config=self.vector_store_config,
             )
 
-
         elif self.file_type == "default":
             embedding = MarkdownEmbedding(
                 file_path=self.file_path,

@@ -71,7 +77,6 @@ class KnowledgeEmbedding:
                 vector_store_config=self.vector_store_config,
             )
 
-
         return embedding
 
     def similar_search(self, text, topk):

@@ -13,6 +13,7 @@ from pilot.summary.mysql_db_summary import MysqlSummary
 from pilot.scene.chat_factory import ChatFactory
 
 CFG = Config()
+chat_factory = ChatFactory()
 
 
 class DBSummaryClient:

@@ -88,13 +89,18 @@ class DBSummaryClient:
         )
         if CFG.SUMMARY_CONFIG == "FAST":
             table_docs = knowledge_embedding_client.similar_search(query, topk)
-            related_tables = [json.loads(table_doc.page_content)["table_name"] for table_doc in table_docs]
+            related_tables = [
+                json.loads(table_doc.page_content)["table_name"]
+                for table_doc in table_docs
+            ]
         else:
             table_docs = knowledge_embedding_client.similar_search(query, 1)
             # prompt = KnownLedgeBaseQA.build_db_summary_prompt(
            #     query, table_docs[0].page_content
             # )
-            related_tables = _get_llm_response(query, dbname, table_docs[0].page_content)
+            related_tables = _get_llm_response(
+                query, dbname, table_docs[0].page_content
+            )
         related_table_summaries = []
         for table in related_tables:
             vector_store_config = {
@@ -118,35 +124,14 @@ def _get_llm_response(query, db_input, dbsummary):
         "max_new_tokens": 512,
         "chat_session_id": uuid.uuid1(),
         "user_input": query,
-        "db_input": db_input,
+        "db_select": db_input,
         "db_summary": dbsummary,
     }
-    chat_factory = ChatFactory()
-    chat: BaseChat = chat_factory.get_implementation(ChatScene.InnerChatDBSummary.value(), **chat_param)
-
-    return chat.call()
-    # payload = {
-    #     "model": CFG.LLM_MODEL,
-    #     "prompt": prompt,
-    #     "temperature": float(0.7),
-    #     "max_new_tokens": int(512),
-    #     "stop": state.sep
-    #     if state.sep_style == SeparatorStyle.SINGLE
-    #     else state.sep2,
-    # }
-    # headers = {"User-Agent": "dbgpt Client"}
-    # response = requests.post(
-    #     urljoin(CFG.MODEL_SERVER, "generate"),
-    #     headers=headers,
-    #     json=payload,
-    #     timeout=120,
-    # )
-    #
-    # print(related_tables)
-    # return related_tables
-    # except NotCommands as e:
-    #     print("llm response error:" + e.message)
-
+    chat: BaseChat = chat_factory.get_implementation(
+        ChatScene.InnerChatDBSummary.value, **chat_param
+    )
+    res = chat.nostream_call()
+    return json.loads(res)["table"]
 
 # if __name__ == "__main__":
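The rewritten tail of _get_llm_response now parses the model's JSON answer instead of returning the raw completion. A small sketch of what the last two lines do (the response string is hypothetical):

    import json

    res = '{"table": ["orders", "products"]}'  # hypothetical nostream_call() output
    related_tables = json.loads(res)["table"]
    print(related_tables)  # ['orders', 'products']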
@@ -37,20 +37,23 @@ class MysqlSummary(DBSummary):
                     table_name=table_comment[0], table_comment=table_comment[1]
                 )
             )
             vector_table = json.dumps(
                 {"table_name": table_comment[0], "table_description": table_comment[1]}
             )
             self.vector_tables_info.append(
                 vector_table.encode("utf-8").decode("unicode_escape")
             )
 
+        self.table_columns_info = []
         for table_name in tables:
             table_summary = MysqlTableSummary(self.db, name, table_name)
-            self.tables[table_name] = table_summary.get_summery()
+            # self.tables[table_name] = table_summary.get_summery()
+            self.tables[table_name] = table_summary.get_columns()
+            self.table_columns_info.append(table_summary.get_columns())
             # self.tables_info.append(table_summary.get_summery())
 
     def get_summery(self):
-        if CFG.SUMMARY_CONFIG == "VECTOR":
+        if CFG.SUMMARY_CONFIG == "FAST":
             return self.vector_tables_info
         else:
             return self.summery.format(
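A note on the vector_tables_info lines kept above: json.dumps escapes non-ASCII by default, and the encode/decode round-trip restores the readable characters before the text is embedded. A quick sketch with a hypothetical table comment:

    import json

    vector_table = json.dumps({"table_name": "orders", "table_description": "订单表"})
    print(vector_table)
    # {"table_name": "orders", "table_description": "\u8ba2\u5355\u8868"}
    print(vector_table.encode("utf-8").decode("unicode_escape"))
    # {"table_name": "orders", "table_description": "订单表"}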
@@ -63,6 +66,9 @@ class MysqlSummary(DBSummary):
     def get_table_comments(self):
         return self.table_comments
 
+    def get_columns(self):
+        return self.table_columns_info
+
 
 class MysqlTableSummary(TableSummary):
     """Get mysql table summary template."""

@@ -78,10 +84,16 @@ class MysqlTableSummary(TableSummary):
         self.db = instance
         fields = self.db.get_fields(name)
         indexes = self.db.get_indexes(name)
+        field_names = []
         for field in fields:
             field_summary = MysqlFieldsSummary(field)
             self.fields.append(field_summary)
             self.fields_info.append(field_summary.get_summery())
+            field_names.append(field[0])
+
+        self.column_summery = """{name}({columns_info})""".format(
+            name=name, columns_info=",".join(field_names)
+        )
 
         for index in indexes:
             index_summary = MysqlIndexSummary(index)

@@ -96,6 +108,9 @@ class MysqlTableSummary(TableSummary):
             indexes=";".join(self.indexes_info),
         )
 
+    def get_columns(self):
+        return self.column_summery
+
 
 class MysqlFieldsSummary(FieldSummary):
     """Get mysql field summary template."""

@@ -1,4 +1,5 @@
 from pilot.vector_store.chroma_store import ChromaStore
+
 # from pilot.vector_store.milvus_store import MilvusStore
 
 connector = {"Chroma": ChromaStore, "Milvus": None}