update:db chat

This commit is contained in:
aries-ckt 2023-06-13 17:12:20 +08:00
parent 5d440dce94
commit 7aa59596d2
9 changed files with 176 additions and 29 deletions

View File

@ -181,6 +181,13 @@ In the .env configuration file, modify the LANGUAGE parameter to switch between
We currently support many document formats: txt, pdf, md, html, doc, ppt, and url. We currently support many document formats: txt, pdf, md, html, doc, ppt, and url.
before execution:
```
python -m spacy download zh_core_web_sm
```
2.set .env configuration set your vector store type, eg:VECTOR_STORE_TYPE=Chroma, now we support Chroma and Milvus(version > 2.1) 2.set .env configuration set your vector store type, eg:VECTOR_STORE_TYPE=Chroma, now we support Chroma and Milvus(version > 2.1)
3.Run the knowledge repository script in the tools directory. 3.Run the knowledge repository script in the tools directory.

View File

@ -178,6 +178,13 @@ $ python webserver.py
当前支持的文档格式: txt, pdf, md, html, doc, ppt, and url. 当前支持的文档格式: txt, pdf, md, html, doc, ppt, and url.
在操作之前先执行
```
python -m spacy download zh_core_web_sm
```
2.在.env文件指定你的向量数据库类型,VECTOR_STORE_TYPE(默认Chroma),目前支持Chroma,Milvus(需要设置MILVUS_URL和MILVUS_PORT) 2.在.env文件指定你的向量数据库类型,VECTOR_STORE_TYPE(默认Chroma),目前支持Chroma,Milvus(需要设置MILVUS_URL和MILVUS_PORT)
注意Milvus版本需要>2.1 注意Milvus版本需要>2.1

View File

@ -12,6 +12,12 @@ As the knowledge base is currently the most significant user demand scenario, we
We currently support many document formats: txt, pdf, md, html, doc, ppt, and url. We currently support many document formats: txt, pdf, md, html, doc, ppt, and url.
before execution:
```
python -m spacy download zh_core_web_sm
```
2.Update your .env, set your vector store type, VECTOR_STORE_TYPE=Chroma 2.Update your .env, set your vector store type, VECTOR_STORE_TYPE=Chroma
(now only support Chroma and Milvus, if you set Milvus, please set MILVUS_URL and MILVUS_PORT) (now only support Chroma and Milvus, if you set Milvus, please set MILVUS_URL and MILVUS_PORT)

View File

@ -443,6 +443,14 @@ class Database:
indexes = cursor.fetchall() indexes = cursor.fetchall()
return [(index[2], index[4]) for index in indexes] return [(index[2], index[4]) for index in indexes]
def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}"))
ans = cursor.fetchall()
return ans[0][1]
def get_fields(self, table_name): def get_fields(self, table_name):
"""Get column fields about specified table.""" """Get column fields about specified table."""
session = self._db_sessions() session = self._db_sessions()

View File

@ -7,7 +7,7 @@ lang_dicts = {
"learn_more_markdown": "该服务是仅供非商业用途的研究预览。受 Vicuna-13B 模型 [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) 的约束", "learn_more_markdown": "该服务是仅供非商业用途的研究预览。受 Vicuna-13B 模型 [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) 的约束",
"model_control_param": "模型参数", "model_control_param": "模型参数",
"sql_generate_mode_direct": "直接执行结果", "sql_generate_mode_direct": "直接执行结果",
"sql_generate_mode_none": "不直接执行结果", "sql_generate_mode_none": "db问答",
"max_input_token_size": "最大输出Token数", "max_input_token_size": "最大输出Token数",
"please_choose_database": "请选择数据", "please_choose_database": "请选择数据",
"sql_generate_diagnostics": "SQL生成与诊断", "sql_generate_diagnostics": "SQL生成与诊断",
@ -44,7 +44,7 @@ lang_dicts = {
"learn_more_markdown": "The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B", "learn_more_markdown": "The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B",
"model_control_param": "Model Parameters", "model_control_param": "Model Parameters",
"sql_generate_mode_direct": "Execute directly", "sql_generate_mode_direct": "Execute directly",
"sql_generate_mode_none": "Execute without mode", "sql_generate_mode_none": "chat to db",
"max_input_token_size": "Maximum output token size", "max_input_token_size": "Maximum output token size",
"please_choose_database": "Please choose database", "please_choose_database": "Please choose database",
"sql_generate_diagnostics": "SQL Generation & Diagnostics", "sql_generate_diagnostics": "SQL Generation & Diagnostics",

View File

@ -52,7 +52,7 @@ class ChatWithDbQA(BaseChat):
raise ValueError("Could not import DBSummaryClient. ") raise ValueError("Could not import DBSummaryClient. ")
if self.db_name: if self.db_name:
client = DBSummaryClient() client = DBSummaryClient()
table_info = client.get_similar_tables( table_info = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k dbname=self.db_name, query=self.current_user_input, topk=self.top_k
) )
# table_info = self.database.table_simple_info(self.db_connect) # table_info = self.database.table_simple_info(self.db_connect)
@ -60,8 +60,8 @@ class ChatWithDbQA(BaseChat):
input_values = { input_values = {
"input": self.current_user_input, "input": self.current_user_input,
"top_k": str(self.top_k), # "top_k": str(self.top_k),
"dialect": dialect, # "dialect": dialect,
"table_info": table_info, "table_info": table_info,
} }
return input_values return input_values

View File

@ -10,22 +10,44 @@ CFG = Config()
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. """ PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. """
PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info: # PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info:
# {table_info}
#
# Question: {input}
#
# """
# _DEFAULT_TEMPLATE = """
# You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
# Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
# You can order the results by a relevant column to return the most interesting examples in the database.
# Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
# Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
#
# """
_DEFAULT_TEMPLATE_EN = """
You are a database expert. you will be given metadata information about a database or table, and then provide a brief summary and answer to the question. For example, question: "How many tables are there in database 'db_gpt'?" , answer: "There are 5 tables in database 'db_gpt', which are 'book', 'book_category', 'borrower', 'borrowing', and 'category'.
Based on the database metadata information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
database metadata information:
{table_info} {table_info}
question:
Question: {input} {input}
""" """
_DEFAULT_TEMPLATE = """ _DEFAULT_TEMPLATE_ZH = """
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. 你是一位数据库专家你将获得有关数据库或表的元数据信息然后提供简要的总结和回答例如问题数据库 'db_gpt' 中有多少个表 答案数据库 'db_gpt' 中有 5 个表分别是 'book''book_category''borrower''borrowing' 'category'
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. 根据以下数据库元数据信息为用户提供专业简洁的答案如果无法从提供的内容中获取答案请说知识库中提供的信息不足以回答此问题 禁止随意捏造信息
You can order the results by a relevant column to return the most interesting examples in the database. 数据库元数据信息:
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question. {table_info}
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. 问题:
{input}
""" """
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_SEP = SeparatorStyle.SINGLE.value
@ -33,10 +55,10 @@ PROMPT_NEED_NEED_STREAM_OUT = True
prompt = PromptTemplate( prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbQA.value, template_scene=ChatScene.ChatWithDbQA.value,
input_variables=["input", "table_info", "dialect", "top_k"], input_variables=["input", "table_info"],
response_format=None, response_format=None,
template_define=PROMPT_SCENE_DEFINE, template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX, template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT, stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser( output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT

View File

@ -32,13 +32,14 @@ class DBSummaryClient:
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL] model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
) )
vector_store_config = { vector_store_config = {
"vector_store_name": dbname + "_profile", "vector_store_name": dbname + "_summary",
"embeddings": embeddings, "embeddings": embeddings,
} }
embedding = StringEmbedding( embedding = StringEmbedding(
file_path=db_summary_client.get_summery(), file_path=db_summary_client.get_summery(),
vector_store_config=vector_store_config, vector_store_config=vector_store_config,
) )
self.init_db_profile(db_summary_client, dbname, embeddings)
if not embedding.vector_name_exist(): if not embedding.vector_name_exist():
if CFG.SUMMARY_CONFIG == "FAST": if CFG.SUMMARY_CONFIG == "FAST":
for vector_table_info in db_summary_client.get_summery(): for vector_table_info in db_summary_client.get_summery():
@ -69,10 +70,22 @@ class DBSummaryClient:
logger.info("db summary embedding success") logger.info("db summary embedding success")
def get_db_summary(self, dbname, query, topk):
vector_store_config = {
"vector_store_name": dbname + "_profile",
}
knowledge_embedding_client = KnowledgeEmbedding(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
)
table_docs =knowledge_embedding_client.similar_search(query, topk)
ans = [d.page_content for d in table_docs]
return ans
def get_similar_tables(self, dbname, query, topk): def get_similar_tables(self, dbname, query, topk):
"""get user query related tables info""" """get user query related tables info"""
vector_store_config = { vector_store_config = {
"vector_store_name": dbname + "_profile", "vector_store_name": dbname + "_summary",
} }
knowledge_embedding_client = KnowledgeEmbedding( knowledge_embedding_client = KnowledgeEmbedding(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
@ -112,6 +125,29 @@ class DBSummaryClient:
for dbname in dbs: for dbname in dbs:
self.db_summary_embedding(dbname) self.db_summary_embedding(dbname)
def init_db_profile(self, db_summary_client, dbname, embeddings):
profile_store_config = {
"vector_store_name": dbname + "_profile",
"embeddings": embeddings,
}
embedding = StringEmbedding(
file_path=db_summary_client.get_db_summery(),
vector_store_config=profile_store_config,
)
if not embedding.vector_name_exist():
docs = []
docs.extend(embedding.read_batch())
for table_summary in db_summary_client.table_info_json():
embedding = StringEmbedding(
table_summary,
profile_store_config,
)
docs.extend(embedding.read_batch())
embedding.index_to_store(docs)
logger.info("init db profile success...")
def _get_llm_response(query, db_input, dbsummary): def _get_llm_response(query, db_input, dbsummary):
chat_param = { chat_param = {

View File

@ -5,6 +5,43 @@ from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, Inde
CFG = Config() CFG = Config()
# {
# "database_name": "mydatabase",
# "tables": [
# {
# "table_name": "customers",
# "columns": [
# {"name": "id", "type": "int(11)", "is_primary_key": true},
# {"name": "name", "type": "varchar(255)", "is_primary_key": false},
# {"name": "email", "type": "varchar(255)", "is_primary_key": false}
# ],
# "indexes": [
# {"name": "PRIMARY", "type": "primary", "columns": ["id"]},
# {"name": "idx_name", "type": "index", "columns": ["name"]},
# {"name": "idx_email", "type": "index", "columns": ["email"]}
# ],
# "size_in_bytes": 1024,
# "rows": 1000
# },
# {
# "table_name": "orders",
# "columns": [
# {"name": "id", "type": "int(11)", "is_primary_key": true},
# {"name": "customer_id", "type": "int(11)", "is_primary_key": false},
# {"name": "order_date", "type": "date", "is_primary_key": false},
# {"name": "total_amount", "type": "decimal(10,2)", "is_primary_key": false}
# ],
# "indexes": [
# {"name": "PRIMARY", "type": "primary", "columns": ["id"]},
# {"name": "fk_customer_id", "type": "foreign_key", "columns": ["customer_id"], "referenced_table": "customers", "referenced_columns": ["id"]}
# ],
# "size_in_bytes": 2048,
# "rows": 500
# }
# ],
# "qps": 100,
# "tps": 50
# }
class MysqlSummary(DBSummary): class MysqlSummary(DBSummary):
"""Get mysql summary template.""" """Get mysql summary template."""
@ -13,7 +50,7 @@ class MysqlSummary(DBSummary):
self.name = name self.name = name
self.type = "MYSQL" self.type = "MYSQL"
self.summery = ( self.summery = (
"""database name:{name}, database type:{type}, table infos:{table_info}""" """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
) )
self.tables = {} self.tables = {}
self.tables_info = [] self.tables_info = []
@ -31,12 +68,14 @@ class MysqlSummary(DBSummary):
) )
tables = self.db.get_table_names() tables = self.db.get_table_names()
self.table_comments = self.db.get_table_comments(name) self.table_comments = self.db.get_table_comments(name)
comment_map = {}
for table_comment in self.table_comments: for table_comment in self.table_comments:
self.tables_info.append( self.tables_info.append(
"table name:{table_name},table description:{table_comment}".format( "table name:{table_name},table description:{table_comment}".format(
table_name=table_comment[0], table_comment=table_comment[1] table_name=table_comment[0], table_comment=table_comment[1]
) )
) )
comment_map[table_comment[0]] = table_comment[1]
vector_table = json.dumps( vector_table = json.dumps(
{"table_name": table_comment[0], "table_description": table_comment[1]} {"table_name": table_comment[0], "table_description": table_comment[1]}
@ -45,11 +84,18 @@ class MysqlSummary(DBSummary):
vector_table.encode("utf-8").decode("unicode_escape") vector_table.encode("utf-8").decode("unicode_escape")
) )
self.table_columns_info = [] self.table_columns_info = []
self.table_columns_json = []
for table_name in tables: for table_name in tables:
table_summary = MysqlTableSummary(self.db, name, table_name) table_summary = MysqlTableSummary(self.db, name, table_name, comment_map)
# self.tables[table_name] = table_summary.get_summery() # self.tables[table_name] = table_summary.get_summery()
self.tables[table_name] = table_summary.get_columns() self.tables[table_name] = table_summary.get_columns()
self.table_columns_info.append(table_summary.get_columns()) self.table_columns_info.append(table_summary.get_columns())
# self.table_columns_json.append(table_summary.get_summary_json())
table_profile = "table name:{table_name},table description:{table_comment}".format(
table_name=table_name, table_comment=self.db.get_show_create_table(table_name)
)
self.table_columns_json.append(table_profile)
# self.tables_info.append(table_summary.get_summery()) # self.tables_info.append(table_summary.get_summery())
def get_summery(self): def get_summery(self):
@ -60,23 +106,29 @@ class MysqlSummary(DBSummary):
name=self.name, type=self.type, table_info=";".join(self.tables_info) name=self.name, type=self.type, table_info=";".join(self.tables_info)
) )
def get_db_summery(self):
return self.summery.format(
name=self.name, type=self.type, tables=";".join(self.vector_tables_info), qps=1000, tps=1000
)
def get_table_summary(self): def get_table_summary(self):
return self.tables return self.tables
def get_table_comments(self): def get_table_comments(self):
return self.table_comments return self.table_comments
def get_columns(self): def table_info_json(self):
return self.table_columns_info return self.table_columns_json
class MysqlTableSummary(TableSummary): class MysqlTableSummary(TableSummary):
"""Get mysql table summary template.""" """Get mysql table summary template."""
def __init__(self, instance, dbname, name): def __init__(self, instance, dbname, name, comment_map):
self.name = name self.name = name
self.dbname = dbname self.dbname = dbname
self.summery = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}""" self.summery = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}"""
self.json_summery_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}"""
self.fields = [] self.fields = []
self.fields_info = [] self.fields_info = []
self.indexes = [] self.indexes = []
@ -100,6 +152,10 @@ class MysqlTableSummary(TableSummary):
self.indexes.append(index_summary) self.indexes.append(index_summary)
self.indexes_info.append(index_summary.get_summery()) self.indexes_info.append(index_summary.get_summery())
self.json_summery = self.json_summery_template.format(
name=name, comment=comment_map[name], fields=self.fields_info, indexes=self.indexes_info, size_in_bytes=1000, rows=1000
)
def get_summery(self): def get_summery(self):
return self.summery.format( return self.summery.format(
name=self.name, name=self.name,
@ -111,20 +167,24 @@ class MysqlTableSummary(TableSummary):
def get_columns(self): def get_columns(self):
return self.column_summery return self.column_summery
def get_summary_json(self):
return self.json_summery
class MysqlFieldsSummary(FieldSummary): class MysqlFieldsSummary(FieldSummary):
"""Get mysql field summary template.""" """Get mysql field summary template."""
def __init__(self, field): def __init__(self, field):
self.name = field[0] self.name = field[0]
self.summery = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """ # self.summery = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """
# self.summery = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}"""
self.data_type = field[1] self.data_type = field[1]
self.default_value = field[2] self.default_value = field[2]
self.is_nullable = field[3] self.is_nullable = field[3]
self.comment = field[4] self.comment = field[4]
def get_summery(self): def get_summery(self):
return self.summery.format( return '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'.format(
name=self.name, name=self.name,
data_type=self.data_type, data_type=self.data_type,
is_nullable=self.is_nullable, is_nullable=self.is_nullable,
@ -138,11 +198,12 @@ class MysqlIndexSummary(IndexSummary):
def __init__(self, index): def __init__(self, index):
self.name = index[0] self.name = index[0]
self.summery = """index name:{name}, index bind columns:{bind_fields}""" # self.summery = """index name:{name}, index bind columns:{bind_fields}"""
self.summery_template = '{{"name": "{name}", "columns": {bind_fields}}}'
self.bind_fields = index[1] self.bind_fields = index[1]
def get_summery(self): def get_summery(self):
return self.summery.format(name=self.name, bind_fields=self.bind_fields) return self.summery_template.format(name=self.name, bind_fields=self.bind_fields)
if __name__ == "__main__": if __name__ == "__main__":