Merge branch 'summary' into dev

# Conflicts:
#	pilot/model/llm_out/guanaco_llm.py
#	pilot/out_parser/base.py
#	pilot/scene/base_chat.py
yhjun1026 2023-06-01 19:15:00 +08:00
commit 7e6cf9c9e0
43 changed files with 308 additions and 276 deletions

View File

@ -55,8 +55,6 @@ def fix_and_parse_json(
logger.error("参数解析错误", e)
def correct_json(json_to_load: str) -> str:
"""
Correct common JSON errors.

View File

@ -1,6 +1,7 @@
from collections import OrderedDict
from collections import deque
class FixedSizeDict(OrderedDict):
def __init__(self, max_size):
super().__init__()
@ -11,6 +12,7 @@ class FixedSizeDict(OrderedDict):
self.popitem(last=False)
super().__setitem__(key, value)
class FixedSizeList:
def __init__(self, max_size):
self.max_size = max_size
@ -29,4 +31,4 @@ class FixedSizeList:
return len(self.list)
def __str__(self):
return str(list(self.list))
return str(list(self.list))
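A minimal, self-contained sketch of the fixed-size container pattern this file implements, reconstructed from the hunk above (the eviction via popitem(last=False) is shown in the diff; the usage example is illustrative):

from collections import OrderedDict

class FixedSizeDict(OrderedDict):
    """Dict that evicts its oldest entry once max_size is exceeded."""

    def __init__(self, max_size):
        super().__init__()
        self.max_size = max_size

    def __setitem__(self, key, value):
        if len(self) >= self.max_size:
            self.popitem(last=False)  # drop the oldest (first-inserted) item
        super().__setitem__(key, value)

d = FixedSizeDict(2)
d["a"], d["b"], d["c"] = 1, 2, 3  # inserting "c" evicts "a"
assert list(d) == ["b", "c"]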

View File

@ -1,6 +1,6 @@
from __future__ import annotations
import sqlparse
import regex as re
import regex as re
import warnings
from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field, root_validator, validator, Extra
@ -283,16 +283,16 @@ class Database:
return f"Error: {e}"
def __write(self, session, write_sql):
print(f"Write[{write_sql}]")
db_cache = self.get_session_db(session)
result = session.execute(text(write_sql))
session.commit()
#TODO fix losing the dynamically selected target database after the commit
session.execute(text(f"use `{db_cache}`"))
print(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount
print(f"Write[{write_sql}]")
db_cache = self.get_session_db(session)
result = session.execute(text(write_sql))
session.commit()
# TODO fix losing the dynamically selected target database after the commit
session.execute(text(f"use `{db_cache}`"))
print(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount
def __query(self,session, query, fetch: str = "all"):
def __query(self, session, query, fetch: str = "all"):
"""
only for query
Args:
@ -390,37 +390,44 @@ class Database:
cmd_type = parts[0]
# Dispatch on the SQL command type
if cmd_type == 'insert':
match = re.match(r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower())
if cmd_type == "insert":
match = re.match(
r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
)
if match:
table_name, columns, values = match.groups()
# Split the column list and value list into individual columns and values
columns = columns.split(',')
values = values.split(',')
columns = columns.split(",")
values = values.split(",")
# Build the WHERE clause
where_clause = " AND ".join([f"{col.strip()}={val.strip()}" for col, val in zip(columns, values)])
return f'SELECT * FROM {table_name} WHERE {where_clause}'
where_clause = " AND ".join(
[
f"{col.strip()}={val.strip()}"
for col, val in zip(columns, values)
]
)
return f"SELECT * FROM {table_name} WHERE {where_clause}"
elif cmd_type == 'delete':
elif cmd_type == "delete":
table_name = parts[2] # delete from <table_name> ...
# Return a SELECT statement that selects all rows from that table
return f'SELECT * FROM {table_name}'
return f"SELECT * FROM {table_name}"
elif cmd_type == 'update':
elif cmd_type == "update":
table_name = parts[1]
set_idx = parts.index('set')
where_idx = parts.index('where')
set_idx = parts.index("set")
where_idx = parts.index("where")
# Extract the column name from the `set` clause
set_clause = parts[set_idx + 1: where_idx][0].split('=')[0].strip()
set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip()
# Extract the condition following `where`
where_clause = ' '.join(parts[where_idx + 1:])
where_clause = " ".join(parts[where_idx + 1 :])
# Return a SELECT statement that selects the updated rows
return f'SELECT {set_clause} FROM {table_name} WHERE {where_clause}'
return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
else:
raise ValueError(f"Unsupported SQL command type: {cmd_type}")
def __sql_parse(self, sql):
sql = sql.strip()
sql = sql.strip()
parsed = sqlparse.parse(sql)[0]
sql_type = parsed.get_type()
@ -429,8 +436,6 @@ class Database:
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}")
return parsed, ttype, sql_type
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
session = self._db_sessions()

View File

@ -103,8 +103,12 @@ class Config(metaclass=Singleton):
else:
self.plugins_denylist = []
### Native SQL Execution Capability Control Configuration
self.NATIVE_SQL_CAN_RUN_DDL = os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") =="True"
self.NATIVE_SQL_CAN_RUN_WRITE = os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") =="True"
self.NATIVE_SQL_CAN_RUN_DDL = (
os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True"
)
self.NATIVE_SQL_CAN_RUN_WRITE = (
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
)
### Local database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
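The reformatted flags above rely on string comparison against the literal "True". A small sketch of that pattern and its main caveat (env_flag is a hypothetical helper, not part of this diff):

import os

def env_flag(name: str, default: str = "True") -> bool:
    # Only the exact string "True" enables the flag; "true" or "1" read as False.
    return os.getenv(name, default) == "True"

os.environ["NATIVE_SQL_CAN_RUN_DDL"] = "False"
assert env_flag("NATIVE_SQL_CAN_RUN_DDL") is False
assert env_flag("NATIVE_SQL_CAN_RUN_WRITE") is True  # unset, so default "True" applies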

View File

@ -11,11 +11,9 @@ class BaseConnect(BaseModel, ABC):
type
driver: str
def get_session(self, db_name: str):
pass
def get_table_names(self) -> Iterable[str]:
pass
@ -32,4 +30,4 @@ class BaseConnect(BaseModel, ABC):
pass
def run(self, session, command: str, fetch: str = "all") -> List:
pass
pass

View File

@ -11,8 +11,7 @@ class MySQLConnect(RDBMSDatabase):
Usage:
"""
type:str = "MySQL"
type: str = "MySQL"
connect_url = "mysql+pymysql://"
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
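connect_url above is only the SQLAlchemy driver prefix; a full URL also needs host, port, credentials, and a database (the LOCAL_DB_* settings from the config hunk earlier). A hedged sketch of the assembly, since the actual construction code is not part of this diff:

from urllib.parse import quote_plus

def build_mysql_url(user, password, host, port, db_name):
    # "mysql+pymysql://" matches the connect_url declared by MySQLConnect above
    return f"mysql+pymysql://{user}:{quote_plus(password)}@{host}:{port}/{db_name}"

print(build_mysql_url("root", "p@ss", "127.0.0.1", 3306, "mysql"))
# -> mysql+pymysql://root:p%40ss@127.0.0.1:3306/mysql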

View File

@ -106,7 +106,7 @@ class Conversation:
conv_default = Conversation(
system = None,
system=None,
roles=("human", "ai"),
messages=[],
offset=0,
@ -298,7 +298,6 @@ chat_mode_title = {
"sql_generate_diagnostics": get_lang_text("sql_analysis_and_diagnosis"),
"chat_use_plugin": get_lang_text("chat_use_plugin"),
"knowledge_qa": get_lang_text("knowledge_qa"),
}
conversation_sql_mode = {

View File

@ -29,7 +29,6 @@ lang_dicts = {
"url_input_label": "输入网页地址",
"add_as_new_klg": "添加为新知识库",
"add_file_to_klg": "向知识库中添加文件",
"upload_file": "上传文件",
"add_file": "添加文件",
"upload_and_load_to_klg": "上传并加载到知识库",

View File

@ -9,7 +9,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@torch.inference_mode()
def chatglm_generate_stream(
model, tokenizer, params, device, context_len=2048, stream_interval=2
model, tokenizer, params, device, context_len=2048, stream_interval=2
):
"""Generate text using chatglm model's chat api"""
prompt = params["prompt"]
@ -57,7 +57,7 @@ def chatglm_generate_stream(
# i = 0
for i, (response, new_hist) in enumerate(
model.stream_chat(tokenizer, query, hist, **generate_kwargs)
model.stream_chat(tokenizer, query, hist, **generate_kwargs)
):
if echo:
output = query + " " + response

View File

@ -83,9 +83,7 @@ class BaseOutputParser(ABC):
output = self.__post_process_code(output)
yield output
else:
output = (
data["text"] + f" (error_code: {data['error_code']})"
)
output = data["text"] + f" (error_code: {data['error_code']})"
yield output
def parse_model_nostream_resp(self, response, sep: str):

View File

@ -133,10 +133,8 @@ class PluginPromptGenerator:
else:
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
def generate_commands_string(self)->str:
return f"{self._generate_numbered_list(self.commands, item_type='command')}"
def generate_commands_string(self) -> str:
return f"{self._generate_numbered_list(self.commands, item_type='command')}"
def generate_prompt_string(self) -> str:
"""

View File

@ -33,13 +33,13 @@ class PromptTemplate(BaseModel, ABC):
"""A list of the names of the variables the prompt template expects."""
template_scene: Optional[str]
template_define: Optional[str]
template_define: Optional[str]
"""this template define"""
template: Optional[str]
template: Optional[str]
"""The prompt template."""
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
response_format: Optional[str]
response_format: Optional[str]
"""default use stream out"""
stream_out: bool = True
""""""
@ -62,7 +62,9 @@ class PromptTemplate(BaseModel, ABC):
if self.template:
if self.response_format:
kwargs["response"] = json.dumps(self.response_format, indent=4)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
return DEFAULT_FORMATTER_MAPPING[self.template_format](
self.template, **kwargs
)
def add_goals(self, goal: str) -> None:
self.goals.append(goal)
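DEFAULT_FORMATTER_MAPPING above dispatches on template_format, whose documented options are 'f-string' and 'jinja2'. A minimal sketch of what such a mapping plausibly looks like (the real definition is not shown in this diff):

def _format_f_string(template: str, **kwargs) -> str:
    return template.format(**kwargs)

def _format_jinja2(template: str, **kwargs) -> str:
    from jinja2 import Template  # assumed dependency for the 'jinja2' option
    return Template(template).render(**kwargs)

DEFAULT_FORMATTER_MAPPING = {
    "f-string": _format_f_string,
    "jinja2": _format_jinja2,
}

assert DEFAULT_FORMATTER_MAPPING["f-string"]("Hi {name}", name="DB-GPT") == "Hi DB-GPT"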

View File

@ -129,7 +129,7 @@ class BaseChat(ABC):
def stream_call(self):
payload = self.__call_base()
self.skip_echo_len = len(payload.get('prompt').replace("</s>", " "))
self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
@ -141,7 +141,7 @@ class BaseChat(ABC):
stream=True,
timeout=120,
)
return response;
return response
# yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len)
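stream_call now returns the streaming requests.Response directly instead of yielding parsed chunks. For context, a sketch of the consumer side, matching the webserver hunk later in this commit that iterates NUL-delimited chunks (the endpoint URL and chunk JSON shape are assumptions):

import json
import requests

def consume_stream(url: str, payload: dict):
    response = requests.post(url, json=payload, stream=True, timeout=120)
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            yield data["text"]  # assumed chunk shape: {"text": ..., "error_code": ...}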

View File

@ -21,15 +21,21 @@ class ChatWithDbAutoExecute(BaseChat):
"""Number of results to return from the query"""
def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input):
def __init__(
self, temperature, max_new_tokens, chat_session_id, db_name, user_input
):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatWithDbExecute,
chat_session_id=chat_session_id,
current_user_input=user_input)
super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatWithDbExecute,
chat_session_id=chat_session_id,
current_user_input=user_input,
)
if not db_name:
raise ValueError(f"{ChatScene.ChatWithDbExecute.value} mode should chose db!")
raise ValueError(
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
)
self.db_name = db_name
self.database = CFG.local_db
# Prepare DB info (get the connection for the specified database)
@ -40,9 +46,7 @@ class ChatWithDbAutoExecute(BaseChat):
try:
from pilot.summary.db_summary_client import DBSummaryClient
except ImportError:
raise ValueError(
"Could not import DBSummaryClient. "
)
raise ValueError("Could not import DBSummaryClient. ")
input_values = {
"input": self.current_user_input,
"top_k": str(self.top_k),

View File

@ -20,9 +20,8 @@ class DbChatOutputParser(BaseOutputParser):
def __init__(self, sep: str, is_stream_out: bool):
super().__init__(sep=sep, is_stream_out=is_stream_out)
def parse_prompt_response(self, model_out_text):
clean_str = super().parse_prompt_response(model_out_text);
clean_str = super().parse_prompt_response(model_out_text)
print("clean prompt response:", clean_str)
response = json.loads(clean_str)
sql, thoughts = response["sql"], response["thoughts"]

View File

@ -62,4 +62,3 @@ prompt = PromptTemplate(
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -19,13 +19,17 @@ class ChatWithDbQA(BaseChat):
"""Number of results to return from the query"""
def __init__(self,temperature, max_new_tokens, chat_session_id, db_name, user_input):
def __init__(
self, temperature, max_new_tokens, chat_session_id, db_name, user_input
):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatWithDbQA,
chat_session_id=chat_session_id,
current_user_input=user_input)
super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatWithDbQA,
chat_session_id=chat_session_id,
current_user_input=user_input,
)
self.db_name = db_name
if db_name:
self.database = CFG.local_db
@ -34,17 +38,16 @@ class ChatWithDbQA(BaseChat):
self.top_k: int = 5
def generate_input_values(self):
table_info = ""
dialect = "mysql"
try:
from pilot.summary.db_summary_client import DBSummaryClient
except ImportError:
raise ValueError(
"Could not import DBSummaryClient. "
)
raise ValueError("Could not import DBSummaryClient. ")
if self.db_name:
table_info = DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
table_info = DBSummaryClient.get_similar_tables(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
)
# table_info = self.database.table_simple_info(self.db_connect)
dialect = self.database.dialect
@ -52,7 +55,7 @@ class ChatWithDbQA(BaseChat):
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": dialect,
"table_info": table_info
"table_info": table_info,
}
return input_values

View File

@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text

View File

@ -27,7 +27,6 @@ Pay attention to use only the column names that you can see in the schema descri
"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True
@ -37,7 +36,7 @@ prompt = PromptTemplate(
input_variables=["input", "table_info", "dialect", "top_k"],
response_format=None,
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX ,
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
@ -45,4 +44,3 @@ prompt = PromptTemplate(
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -14,56 +14,72 @@ from pilot.scene.chat_execution.prompt import prompt
CFG = Config()
class ChatWithPlugin(BaseChat):
chat_scene: str = ChatScene.ChatExecution.value
plugins_prompt_generator:PluginPromptGenerator
plugins_prompt_generator: PluginPromptGenerator
select_plugin: str = None
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, plugin_selector:str=None):
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatExecution,
chat_session_id=chat_session_id,
current_user_input=user_input)
def __init__(
self,
temperature,
max_new_tokens,
chat_session_id,
user_input,
plugin_selector: str = None,
):
super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatExecution,
chat_session_id=chat_session_id,
current_user_input=user_input,
)
self.plugins_prompt_generator = PluginPromptGenerator()
self.plugins_prompt_generator.command_registry = CFG.command_registry
# Load the commands available from plugins
self.select_plugin = plugin_selector
if self.select_plugin:
for plugin in CFG.plugins:
if plugin._name == plugin_selector :
if plugin._name == plugin_selector:
if not plugin.can_handle_post_prompt():
continue
self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator)
self.plugins_prompt_generator = plugin.post_prompt(
self.plugins_prompt_generator
)
else:
for plugin in CFG.plugins:
if not plugin.can_handle_post_prompt():
continue
self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator)
self.plugins_prompt_generator = plugin.post_prompt(
self.plugins_prompt_generator
)
def generate_input_values(self):
input_values = {
"input": self.current_user_input,
"constraints": self.__list_to_prompt_str(list(self.plugins_prompt_generator.constraints)),
"commands_infos": self.plugins_prompt_generator.generate_commands_string()
"constraints": self.__list_to_prompt_str(
list(self.plugins_prompt_generator.constraints)
),
"commands_infos": self.plugins_prompt_generator.generate_commands_string(),
}
return input_values
def do_with_prompt_response(self, prompt_response):
## plugin command run
return execute_command(str(prompt_response.command.get('name')), prompt_response.command.get('args',{}), self.plugins_prompt_generator)
return execute_command(
str(prompt_response.command.get("name")),
prompt_response.command.get("args", {}),
self.plugins_prompt_generator,
)
def chat_show(self):
super().chat_show()
def __list_to_prompt_str(self, list: List) -> str:
if list:
separator = '\n'
separator = "\n"
return separator.join(list)
else:
return ""

View File

@ -10,14 +10,13 @@ from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class PluginAction(NamedTuple):
command: Dict
thoughts: Dict
class PluginChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
response = json.loads(super().parse_prompt_response(model_out_text))
command, thoughts = response["command"], response["thoughts"]
@ -25,7 +24,7 @@ class PluginChatOutputParser(BaseOutputParser):
def parse_view_response(self, speak, data) -> str:
### tool out data to table view
print(f"parse_view_response:{speak},{str(data)}" )
print(f"parse_view_response:{speak},{str(data)}")
view_text = f"##### {speak}" + "\n" + str(data)
return view_text

View File

@ -53,7 +53,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False
prompt = PromptTemplate(
template_scene=ChatScene.ChatExecution.value,
input_variables=["input", "constraints", "commands_infos", "response"],
input_variables=["input", "constraints", "commands_infos", "response"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,
@ -63,4 +63,4 @@ prompt = PromptTemplate(
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -11,6 +11,7 @@ from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
class ChatFactory(metaclass=Singleton):
@staticmethod
def get_implementation(chat_mode, **kwargs):
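ChatFactory maps a chat-mode string to a concrete BaseChat implementation, mirroring how each scene file registers its prompt via CFG.prompt_templates.update(...). A simplified registry sketch of that pattern (the real factory may discover subclasses rather than use an explicit dict):

class ChatFactorySketch:
    _registry: dict = {}

    @classmethod
    def register(cls, scene_value, chat_cls):
        cls._registry[scene_value] = chat_cls

    @classmethod
    def get_implementation(cls, chat_mode, **kwargs):
        try:
            return cls._registry[chat_mode](**kwargs)
        except KeyError:
            raise ValueError(f"Invalid chat mode: {chat_mode}")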

View File

@ -1,4 +1,3 @@
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
@ -24,18 +23,22 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
CFG = Config()
class ChatNewKnowledge (BaseChat):
class ChatNewKnowledge(BaseChat):
chat_scene: str = ChatScene.ChatNewKnowledge.value
"""Number of results to return from the query"""
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, knowledge_name):
def __init__(
self, temperature, max_new_tokens, chat_session_id, user_input, knowledge_name
):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatNewKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input)
super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatNewKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input,
)
self.knowledge_name = knowledge_name
vector_store_config = {
"vector_store_name": knowledge_name,
@ -49,21 +52,17 @@ class ChatNewKnowledge (BaseChat):
vector_store_config=vector_store_config,
)
def generate_input_values(self):
docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, VECTOR_SEARCH_TOP_K
)
docs = docs[:2000]
input_values = {
"context": docs,
"question": self.current_user_input
}
input_values = {"context": docs, "question": self.current_user_input}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.ChatNewKnowledge.value

View File

@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text

View File

@ -23,7 +23,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用
"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True
@ -42,5 +41,3 @@ prompt = PromptTemplate(
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -1,4 +1,3 @@
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
@ -24,43 +23,42 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
CFG = Config()
class ChatDefaultKnowledge (BaseChat):
class ChatDefaultKnowledge(BaseChat):
chat_scene: str = ChatScene.ChatKnowledge.value
"""Number of results to return from the query"""
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input):
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input)
super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input,
)
vector_store_config = {
"vector_store_name": "default",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
self.knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
def generate_input_values(self):
docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, VECTOR_SEARCH_TOP_K
)
docs = docs[:2000]
input_values = {
"context": docs,
"question": self.current_user_input
}
input_values = {"context": docs, "question": self.current_user_input}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.ChatKnowledge.value

View File

@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text

View File

@ -20,7 +20,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用
"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True
@ -39,5 +38,3 @@ prompt = PromptTemplate(
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -1,4 +1,3 @@
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
@ -8,34 +7,42 @@ from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt
CFG = Config()
class InnerChatDBSummary (BaseChat):
class InnerChatDBSummary(BaseChat):
chat_scene: str = ChatScene.InnerChatDBSummary.value
"""Number of results to return from the query"""
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, db_select, db_summary):
def __init__(
self,
temperature,
max_new_tokens,
chat_session_id,
user_input,
db_select,
db_summary,
):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.InnerChatDBSummary,
chat_session_id=chat_session_id,
current_user_input=user_input)
self.db_name = db_select
self.db_summary = db_summary
super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.InnerChatDBSummary,
chat_session_id=chat_session_id,
current_user_input=user_input,
)
self.db_input = db_select
self.db_summary = db_summary
def generate_input_values(self):
input_values = {
"db_input": self.db_name,
"db_profile_summary": self.db_summary
"db_input": self.db_input,
"db_profile_summary": self.db_summary,
}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.InnerChatDBSummary.value

View File

@ -10,13 +10,15 @@ from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text):
clean_str = super().parse_prompt_response(model_out_text)
print("clean prompt response:", clean_str)
return clean_str
def parse_view_response(self, ai_text, data) -> str:
return ai_text["table"]
return ai_text
def get_format_instructions(self) -> str:
pass

View File

@ -7,33 +7,30 @@ from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_knowledge.inner_db_summary.out_parser import NormalChatOutputParser
from pilot.scene.chat_knowledge.inner_db_summary.out_parser import (
NormalChatOutputParser,
)
CFG = Config()
PROMPT_SCENE_DEFINE =""""""
PROMPT_SCENE_DEFINE = """"""
_DEFAULT_TEMPLATE = """
Based on the following known database information, answer which tables are involved in the user input.
Known database information:{db_profile_summary}
Input:{db_input}
You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads
The response format must be JSON, and the key of JSON must be "table".
"""
PROMPT_RESPONSE = """You must respond in JSON format as following format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads
The response format must be JSON, and the key of JSON must be "table".
"""
RESPONSE_FORMAT = {
"table": ["orders", "products"]
}
RESPONSE_FORMAT = {"table": ["orders", "products"]}
PROMPT_SEP = SeparatorStyle.SINGLE.value
@ -54,5 +51,3 @@ prompt = PromptTemplate(
CFG.prompt_templates.update({prompt.template_scene: prompt})
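The template above pins the model to a JSON object keyed "table": RESPONSE_FORMAT is serialized into the prompt as {response}, and the reply is recovered with json.loads (see the db_summary_client hunk below). A small sketch of that round trip, with a hypothetical model reply:

import json

RESPONSE_FORMAT = {"table": ["orders", "products"]}
prompt_response = json.dumps(RESPONSE_FORMAT, indent=4)  # substituted for {response}

model_reply = '{"table": ["orders"]}'  # hypothetical model output
assert json.loads(model_reply)["table"] == ["orders"]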

View File

@ -1,4 +1,3 @@
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
@ -24,18 +23,20 @@ from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
CFG = Config()
class ChatUrlKnowledge (BaseChat):
class ChatUrlKnowledge(BaseChat):
chat_scene: str = ChatScene.ChatUrlKnowledge.value
"""Number of results to return from the query"""
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, url):
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, url):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatUrlKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input)
super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatUrlKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input,
)
self.url = url
vector_store_config = {
"vector_store_name": url,
@ -54,19 +55,16 @@ class ChatUrlKnowledge (BaseChat):
self.knowledge_embedding_client.knowledge_embedding()
def generate_input_values(self):
docs = self.knowledge_embedding_client.similar_search(self.current_user_input, VECTOR_SEARCH_TOP_K)
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, VECTOR_SEARCH_TOP_K
)
docs = docs[:2000]
input_values = {
"context": docs,
"question": self.current_user_input
}
input_values = {"context": docs, "question": self.current_user_input}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.ChatUrlKnowledge.value

View File

@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text

View File

@ -20,7 +20,6 @@ _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用
"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True
@ -39,5 +38,3 @@ prompt = PromptTemplate(
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -1,4 +1,3 @@
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
@ -19,25 +18,23 @@ class ChatNormal(BaseChat):
"""Number of results to return from the query"""
def __init__(self,temperature, max_new_tokens, chat_session_id, user_input):
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatNormal,
chat_session_id=chat_session_id,
current_user_input=user_input)
super().__init__(
temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.ChatNormal,
chat_session_id=chat_session_id,
current_user_input=user_input,
)
def generate_input_values(self):
input_values = {
"input": self.current_user_input
}
input_values = {"input": self.current_user_input}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.ChatNormal.value

View File

@ -10,8 +10,8 @@ from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text

View File

@ -29,5 +29,3 @@ prompt = PromptTemplate(
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -85,9 +85,7 @@ add_knowledge_base_dialogue = get_lang_text(
"knowledge_qa_type_add_knowledge_base_dialogue"
)
url_knowledge_dialogue = get_lang_text(
"knowledge_qa_type_url_knowledge_dialogue"
)
url_knowledge_dialogue = get_lang_text("knowledge_qa_type_url_knowledge_dialogue")
knowledge_qa_type_list = [
llm_native_dialogue,
@ -205,9 +203,9 @@ def post_process_code(code):
def get_chat_mode(selected, param=None) -> ChatScene:
if chat_mode_title['chat_use_plugin'] == selected:
if chat_mode_title["chat_use_plugin"] == selected:
return ChatScene.ChatExecution
elif chat_mode_title['knowledge_qa'] == selected:
elif chat_mode_title["knowledge_qa"] == selected:
mode = param
if mode == conversation_types["default_knownledge"]:
return ChatScene.ChatKnowledge
@ -232,14 +230,23 @@ def chatbot_callback(state, message):
def http_bot(
state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input,
knowledge_name
state,
selected,
temperature,
max_new_tokens,
plugin_selector,
mode,
sql_mode,
db_selector,
url_input,
knowledge_name,
):
logger.info(
f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
if chat_mode_title['knowledge_qa'] == selected:
f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}"
)
if chat_mode_title["knowledge_qa"] == selected:
scene: ChatScene = get_chat_mode(selected, mode)
elif chat_mode_title['chat_use_plugin'] == selected:
elif chat_mode_title["chat_use_plugin"] == selected:
scene: ChatScene = get_chat_mode(selected)
else:
scene: ChatScene = get_chat_mode(selected, sql_mode)
@ -251,7 +258,7 @@ def http_bot(
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"db_name": db_selector,
"user_input": state.last_user_input
"user_input": state.last_user_input,
}
elif ChatScene.ChatWithDbQA == scene:
chat_param = {
@ -289,7 +296,7 @@ def http_bot(
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
"knowledge_name": knowledge_name
"knowledge_name": knowledge_name,
}
elif ChatScene.ChatUrlKnowledge == scene:
chat_param = {
@ -297,7 +304,7 @@ def http_bot(
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
"url": url_input
"url": url_input,
}
else:
state.messages[-1][-1] = f"ERROR: Can't support scene!{scene}"
@ -314,7 +321,11 @@ def http_bot(
response = chat.stream_call()
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
state.messages[-1][-1] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk,chat.skip_echo_len)
state.messages[-1][
-1
] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
except Exception as e:
print(traceback.format_exc())
@ -323,8 +334,8 @@ def http_bot(
block_css = (
code_highlight_css
+ """
code_highlight_css
+ """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@ -361,7 +372,7 @@ def build_single_model_ui():
gr.Markdown(notice_markdown, elem_id="notice_markdown")
with gr.Accordion(
get_lang_text("model_control_param"), open=False, visible=False
get_lang_text("model_control_param"), open=False, visible=False
) as parameter_row:
temperature = gr.Slider(
minimum=0.0,
@ -411,7 +422,7 @@ def build_single_model_ui():
get_lang_text("sql_generate_mode_none"),
],
show_label=False,
value=get_lang_text("sql_generate_mode_none")
value=get_lang_text("sql_generate_mode_none"),
)
sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
@ -428,15 +439,19 @@ def build_single_model_ui():
value="",
interactive=True,
show_label=True,
type="value"
type="value",
).style(container=False)
def plugin_change(evt: gr.SelectData): # SelectData is a subclass of EventData
def plugin_change(
evt: gr.SelectData,
): # SelectData is a subclass of EventData
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
print(f"user plugin:{plugins_select_info().get(evt.value)}")
return plugins_select_info().get(evt.value)
plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
plugin_selected = gr.Textbox(
show_label=False, visible=False, placeholder="Selected"
)
plugin_selector.select(plugin_change, None, plugin_selected)
tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
@ -456,7 +471,12 @@ def build_single_model_ui():
)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False)
url_input = gr.Textbox(
label=get_lang_text("url_input_label"),
lines=1,
interactive=True,
visible=False,
)
def show_url_input(evt: gr.SelectData):
if evt.value == url_knowledge_dialogue:
@ -559,10 +579,10 @@ def build_single_model_ui():
def build_webdemo():
with gr.Blocks(
title=get_lang_text("database_smart_assistant"),
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
title=get_lang_text("database_smart_assistant"),
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
) as demo:
url_params = gr.JSON(visible=False)
(

View File

@ -18,7 +18,14 @@ CFG = Config()
class KnowledgeEmbedding:
def __init__(self, file_path, model_name, vector_store_config, local_persist=True, file_type="default"):
def __init__(
self,
file_path,
model_name,
vector_store_config,
local_persist=True,
file_type="default",
):
"""Initialize with Loader url, model_name, vector_store_config"""
self.file_path = file_path
self.model_name = model_name
@ -63,7 +70,6 @@ class KnowledgeEmbedding:
vector_store_config=self.vector_store_config,
)
elif self.file_type == "default":
embedding = MarkdownEmbedding(
file_path=self.file_path,
@ -71,7 +77,6 @@ class KnowledgeEmbedding:
vector_store_config=self.vector_store_config,
)
return embedding
def similar_search(self, text, topk):

View File

@ -13,6 +13,7 @@ from pilot.summary.mysql_db_summary import MysqlSummary
from pilot.scene.chat_factory import ChatFactory
CFG = Config()
chat_factory = ChatFactory()
class DBSummaryClient:
@ -88,13 +89,18 @@ class DBSummaryClient:
)
if CFG.SUMMARY_CONFIG == "FAST":
table_docs = knowledge_embedding_client.similar_search(query, topk)
related_tables = [json.loads(table_doc.page_content)["table_name"] for table_doc in table_docs]
related_tables = [
json.loads(table_doc.page_content)["table_name"]
for table_doc in table_docs
]
else:
table_docs = knowledge_embedding_client.similar_search(query, 1)
# prompt = KnownLedgeBaseQA.build_db_summary_prompt(
# query, table_docs[0].page_content
# )
related_tables = _get_llm_response(query, dbname, table_docs[0].page_content)
related_tables = _get_llm_response(
query, dbname, table_docs[0].page_content
)
related_table_summaries = []
for table in related_tables:
vector_store_config = {
@ -118,35 +124,14 @@ def _get_llm_response(query, db_input, dbsummary):
"max_new_tokens": 512,
"chat_session_id": uuid.uuid1(),
"user_input": query,
"db_input": db_input,
"db_select": db_input,
"db_summary": dbsummary,
}
chat_factory = ChatFactory()
chat: BaseChat = chat_factory.get_implementation(ChatScene.InnerChatDBSummary.value(), **chat_param)
return chat.call()
# payload = {
# "model": CFG.LLM_MODEL,
# "prompt": prompt,
# "temperature": float(0.7),
# "max_new_tokens": int(512),
# "stop": state.sep
# if state.sep_style == SeparatorStyle.SINGLE
# else state.sep2,
# }
# headers = {"User-Agent": "dbgpt Client"}
# response = requests.post(
# urljoin(CFG.MODEL_SERVER, "generate"),
# headers=headers,
# json=payload,
# timeout=120,
# )
#
# print(related_tables)
# return related_tables
# except NotCommands as e:
# print("llm response error:" + e.message)
chat: BaseChat = chat_factory.get_implementation(
ChatScene.InnerChatDBSummary.value, **chat_param
)
res = chat.nostream_call()
return json.loads(res)["table"]
# if __name__ == "__main__":
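In the FAST path above, each vector hit stores a JSON blob in page_content and the client recovers table names with json.loads. A sketch with a stand-in document type (Doc is hypothetical; the real objects come from the vector store):

import json
from dataclasses import dataclass

@dataclass
class Doc:
    page_content: str

table_docs = [
    Doc('{"table_name": "orders", "table_description": "customer orders"}'),
    Doc('{"table_name": "products", "table_description": "product catalog"}'),
]
related_tables = [json.loads(d.page_content)["table_name"] for d in table_docs]
assert related_tables == ["orders", "products"]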

View File

@ -37,20 +37,23 @@ class MysqlSummary(DBSummary):
table_name=table_comment[0], table_comment=table_comment[1]
)
)
vector_table = json.dumps(
{"table_name": table_comment[0], "table_description": table_comment[1]}
)
self.vector_tables_info.append(
vector_table.encode("utf-8").decode("unicode_escape")
)
self.table_columns_info = []
for table_name in tables:
table_summary = MysqlTableSummary(self.db, name, table_name)
self.tables[table_name] = table_summary.get_summery()
# self.tables[table_name] = table_summary.get_summery()
self.tables[table_name] = table_summary.get_columns()
self.table_columns_info.append(table_summary.get_columns())
# self.tables_info.append(table_summary.get_summery())
def get_summery(self):
if CFG.SUMMARY_CONFIG == "VECTOR":
if CFG.SUMMARY_CONFIG == "FAST":
return self.vector_tables_info
else:
return self.summery.format(
@ -63,6 +66,9 @@ class MysqlSummary(DBSummary):
def get_table_comments(self):
return self.table_comments
def get_columns(self):
return self.table_columns_info
class MysqlTableSummary(TableSummary):
"""Get mysql table summary template."""
@ -78,10 +84,16 @@ class MysqlTableSummary(TableSummary):
self.db = instance
fields = self.db.get_fields(name)
indexes = self.db.get_indexes(name)
field_names = []
for field in fields:
field_summary = MysqlFieldsSummary(field)
self.fields.append(field_summary)
self.fields_info.append(field_summary.get_summery())
field_names.append(field[0])
self.column_summery = """{name}({columns_info})""".format(
name=name, columns_info=",".join(field_names)
)
for index in indexes:
index_summary = MysqlIndexSummary(index)
@ -96,6 +108,9 @@ class MysqlTableSummary(TableSummary):
indexes=";".join(self.indexes_info),
)
def get_columns(self):
return self.column_summery
class MysqlFieldsSummary(FieldSummary):
"""Get mysql field summary template."""

View File

@ -1,4 +1,5 @@
from pilot.vector_store.chroma_store import ChromaStore
# from pilot.vector_store.milvus_store import MilvusStore
connector = {"Chroma": ChromaStore, "Milvus": None}