diff --git a/README.md b/README.md index 3bc57bef8..46c433528 100644 --- a/README.md +++ b/README.md @@ -181,6 +181,13 @@ In the .env configuration file, modify the LANGUAGE parameter to switch between We currently support many document formats: txt, pdf, md, html, doc, ppt, and url. +before execution: + +``` +python -m spacy download zh_core_web_sm + +``` + 2.set .env configuration set your vector store type, eg:VECTOR_STORE_TYPE=Chroma, now we support Chroma and Milvus(version > 2.1) 3.Run the knowledge repository script in the tools directory. diff --git a/README.zh.md b/README.zh.md index fdc72a54b..cbfb67ac0 100644 --- a/README.zh.md +++ b/README.zh.md @@ -178,6 +178,13 @@ $ python webserver.py 当前支持的文档格式: txt, pdf, md, html, doc, ppt, and url. +在操作之前先执行 + +``` +python -m spacy download zh_core_web_sm + +``` + 2.在.env文件指定你的向量数据库类型,VECTOR_STORE_TYPE(默认Chroma),目前支持Chroma,Milvus(需要设置MILVUS_URL和MILVUS_PORT) 注意Milvus版本需要>2.1 diff --git a/docs/modules/knownledge.md b/docs/modules/knownledge.md index 8cf9cd4c3..c108920b2 100644 --- a/docs/modules/knownledge.md +++ b/docs/modules/knownledge.md @@ -12,6 +12,12 @@ As the knowledge base is currently the most significant user demand scenario, we We currently support many document formats: txt, pdf, md, html, doc, ppt, and url. +before execution: + +``` +python -m spacy download zh_core_web_sm + +``` 2.Update your .env, set your vector store type, VECTOR_STORE_TYPE=Chroma (now only support Chroma and Milvus, if you set Milvus, please set MILVUS_URL and MILVUS_PORT) diff --git a/pilot/common/sql_database.py b/pilot/common/sql_database.py index bc3aa8340..f7dbd7164 100644 --- a/pilot/common/sql_database.py +++ b/pilot/common/sql_database.py @@ -443,6 +443,14 @@ class Database: indexes = cursor.fetchall() return [(index[2], index[4]) for index in indexes] + def get_show_create_table(self, table_name): + """Get table show create table about specified table.""" + session = self._db_sessions() + cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}")) + ans = cursor.fetchall() + return ans[0][1] + + def get_fields(self, table_name): """Get column fields about specified table.""" session = self._db_sessions() diff --git a/pilot/language/lang_content_mapping.py b/pilot/language/lang_content_mapping.py index afcfaeaba..e2ea8b4cc 100644 --- a/pilot/language/lang_content_mapping.py +++ b/pilot/language/lang_content_mapping.py @@ -7,7 +7,7 @@ lang_dicts = { "learn_more_markdown": "该服务是仅供非商业用途的研究预览。受 Vicuna-13B 模型 [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) 的约束", "model_control_param": "模型参数", "sql_generate_mode_direct": "直接执行结果", - "sql_generate_mode_none": "不直接执行结果", + "sql_generate_mode_none": "db问答", "max_input_token_size": "最大输出Token数", "please_choose_database": "请选择数据", "sql_generate_diagnostics": "SQL生成与诊断", @@ -44,7 +44,7 @@ lang_dicts = { "learn_more_markdown": "The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B", "model_control_param": "Model Parameters", "sql_generate_mode_direct": "Execute directly", - "sql_generate_mode_none": "Execute without mode", + "sql_generate_mode_none": "chat to db", "max_input_token_size": "Maximum output token size", "please_choose_database": "Please choose database", "sql_generate_diagnostics": "SQL Generation & Diagnostics", diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index cb2425ea9..e956bdc8b 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -52,7 +52,7 @@ class ChatWithDbQA(BaseChat): raise ValueError("Could not import DBSummaryClient. ") if self.db_name: client = DBSummaryClient() - table_info = client.get_similar_tables( + table_info = client.get_db_summary( dbname=self.db_name, query=self.current_user_input, topk=self.top_k ) # table_info = self.database.table_simple_info(self.db_connect) @@ -60,8 +60,8 @@ class ChatWithDbQA(BaseChat): input_values = { "input": self.current_user_input, - "top_k": str(self.top_k), - "dialect": dialect, + # "top_k": str(self.top_k), + # "dialect": dialect, "table_info": table_info, } return input_values diff --git a/pilot/scene/chat_db/professional_qa/prompt.py b/pilot/scene/chat_db/professional_qa/prompt.py index 9cc35b2e4..ff360cb65 100644 --- a/pilot/scene/chat_db/professional_qa/prompt.py +++ b/pilot/scene/chat_db/professional_qa/prompt.py @@ -10,22 +10,44 @@ CFG = Config() PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. """ -PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info: +# PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info: +# {table_info} +# +# Question: {input} +# +# """ + +# _DEFAULT_TEMPLATE = """ +# You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. +# Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. +# You can order the results by a relevant column to return the most interesting examples in the database. +# Never query for all the columns from a specific table, only ask for a the few relevant columns given the question. +# Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. +# +# """ + +_DEFAULT_TEMPLATE_EN = """ +You are a database expert. you will be given metadata information about a database or table, and then provide a brief summary and answer to the question. For example, question: "How many tables are there in database 'db_gpt'?" , answer: "There are 5 tables in database 'db_gpt', which are 'book', 'book_category', 'borrower', 'borrowing', and 'category'. +Based on the database metadata information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly. +database metadata information: {table_info} - -Question: {input} - +question: +{input} """ -_DEFAULT_TEMPLATE = """ -You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. -Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. -You can order the results by a relevant column to return the most interesting examples in the database. -Never query for all the columns from a specific table, only ask for a the few relevant columns given the question. -Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. - +_DEFAULT_TEMPLATE_ZH = """ +你是一位数据库专家。你将获得有关数据库或表的元数据信息,然后提供简要的总结和回答。例如,问题:“数据库 'db_gpt' 中有多少个表?” 答案:“数据库 'db_gpt' 中有 5 个表,分别是 'book'、'book_category'、'borrower'、'borrowing' 和 'category'。” +根据以下数据库元数据信息,为用户提供专业简洁的答案。如果无法从提供的内容中获取答案,请说:“知识库中提供的信息不足以回答此问题。” 禁止随意捏造信息。 +数据库元数据信息: +{table_info} +问题: +{input} """ +_DEFAULT_TEMPLATE = ( + _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH +) + PROMPT_SEP = SeparatorStyle.SINGLE.value @@ -33,10 +55,10 @@ PROMPT_NEED_NEED_STREAM_OUT = True prompt = PromptTemplate( template_scene=ChatScene.ChatWithDbQA.value, - input_variables=["input", "table_info", "dialect", "top_k"], + input_variables=["input", "table_info"], response_format=None, template_define=PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX, + template=_DEFAULT_TEMPLATE, stream_out=PROMPT_NEED_NEED_STREAM_OUT, output_parser=NormalChatOutputParser( sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index 51f124f62..84fbf1550 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -32,13 +32,14 @@ class DBSummaryClient: model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL] ) vector_store_config = { - "vector_store_name": dbname + "_profile", + "vector_store_name": dbname + "_summary", "embeddings": embeddings, } embedding = StringEmbedding( file_path=db_summary_client.get_summery(), vector_store_config=vector_store_config, ) + self.init_db_profile(db_summary_client, dbname, embeddings) if not embedding.vector_name_exist(): if CFG.SUMMARY_CONFIG == "FAST": for vector_table_info in db_summary_client.get_summery(): @@ -69,10 +70,22 @@ class DBSummaryClient: logger.info("db summary embedding success") + def get_db_summary(self, dbname, query, topk): + vector_store_config = { + "vector_store_name": dbname + "_profile", + } + knowledge_embedding_client = KnowledgeEmbedding( + model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], + vector_store_config=vector_store_config, + ) + table_docs =knowledge_embedding_client.similar_search(query, topk) + ans = [d.page_content for d in table_docs] + return ans + def get_similar_tables(self, dbname, query, topk): """get user query related tables info""" vector_store_config = { - "vector_store_name": dbname + "_profile", + "vector_store_name": dbname + "_summary", } knowledge_embedding_client = KnowledgeEmbedding( model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], @@ -112,6 +125,29 @@ class DBSummaryClient: for dbname in dbs: self.db_summary_embedding(dbname) + def init_db_profile(self, db_summary_client, dbname, embeddings): + profile_store_config = { + "vector_store_name": dbname + "_profile", + "embeddings": embeddings, + } + embedding = StringEmbedding( + file_path=db_summary_client.get_db_summery(), + vector_store_config=profile_store_config, + ) + if not embedding.vector_name_exist(): + docs = [] + docs.extend(embedding.read_batch()) + for table_summary in db_summary_client.table_info_json(): + embedding = StringEmbedding( + table_summary, + profile_store_config, + ) + docs.extend(embedding.read_batch()) + embedding.index_to_store(docs) + logger.info("init db profile success...") + + + def _get_llm_response(query, db_input, dbsummary): chat_param = { diff --git a/pilot/summary/mysql_db_summary.py b/pilot/summary/mysql_db_summary.py index a50b24f94..4a578fe2c 100644 --- a/pilot/summary/mysql_db_summary.py +++ b/pilot/summary/mysql_db_summary.py @@ -5,6 +5,43 @@ from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, Inde CFG = Config() +# { +# "database_name": "mydatabase", +# "tables": [ +# { +# "table_name": "customers", +# "columns": [ +# {"name": "id", "type": "int(11)", "is_primary_key": true}, +# {"name": "name", "type": "varchar(255)", "is_primary_key": false}, +# {"name": "email", "type": "varchar(255)", "is_primary_key": false} +# ], +# "indexes": [ +# {"name": "PRIMARY", "type": "primary", "columns": ["id"]}, +# {"name": "idx_name", "type": "index", "columns": ["name"]}, +# {"name": "idx_email", "type": "index", "columns": ["email"]} +# ], +# "size_in_bytes": 1024, +# "rows": 1000 +# }, +# { +# "table_name": "orders", +# "columns": [ +# {"name": "id", "type": "int(11)", "is_primary_key": true}, +# {"name": "customer_id", "type": "int(11)", "is_primary_key": false}, +# {"name": "order_date", "type": "date", "is_primary_key": false}, +# {"name": "total_amount", "type": "decimal(10,2)", "is_primary_key": false} +# ], +# "indexes": [ +# {"name": "PRIMARY", "type": "primary", "columns": ["id"]}, +# {"name": "fk_customer_id", "type": "foreign_key", "columns": ["customer_id"], "referenced_table": "customers", "referenced_columns": ["id"]} +# ], +# "size_in_bytes": 2048, +# "rows": 500 +# } +# ], +# "qps": 100, +# "tps": 50 +# } class MysqlSummary(DBSummary): """Get mysql summary template.""" @@ -13,7 +50,7 @@ class MysqlSummary(DBSummary): self.name = name self.type = "MYSQL" self.summery = ( - """database name:{name}, database type:{type}, table infos:{table_info}""" + """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}""" ) self.tables = {} self.tables_info = [] @@ -31,12 +68,14 @@ class MysqlSummary(DBSummary): ) tables = self.db.get_table_names() self.table_comments = self.db.get_table_comments(name) + comment_map = {} for table_comment in self.table_comments: self.tables_info.append( "table name:{table_name},table description:{table_comment}".format( table_name=table_comment[0], table_comment=table_comment[1] ) ) + comment_map[table_comment[0]] = table_comment[1] vector_table = json.dumps( {"table_name": table_comment[0], "table_description": table_comment[1]} @@ -45,11 +84,18 @@ class MysqlSummary(DBSummary): vector_table.encode("utf-8").decode("unicode_escape") ) self.table_columns_info = [] + self.table_columns_json = [] + for table_name in tables: - table_summary = MysqlTableSummary(self.db, name, table_name) + table_summary = MysqlTableSummary(self.db, name, table_name, comment_map) # self.tables[table_name] = table_summary.get_summery() self.tables[table_name] = table_summary.get_columns() self.table_columns_info.append(table_summary.get_columns()) + # self.table_columns_json.append(table_summary.get_summary_json()) + table_profile = "table name:{table_name},table description:{table_comment}".format( + table_name=table_name, table_comment=self.db.get_show_create_table(table_name) + ) + self.table_columns_json.append(table_profile) # self.tables_info.append(table_summary.get_summery()) def get_summery(self): @@ -60,23 +106,29 @@ class MysqlSummary(DBSummary): name=self.name, type=self.type, table_info=";".join(self.tables_info) ) + def get_db_summery(self): + return self.summery.format( + name=self.name, type=self.type, tables=";".join(self.vector_tables_info), qps=1000, tps=1000 + ) + def get_table_summary(self): return self.tables def get_table_comments(self): return self.table_comments - def get_columns(self): - return self.table_columns_info + def table_info_json(self): + return self.table_columns_json class MysqlTableSummary(TableSummary): """Get mysql table summary template.""" - def __init__(self, instance, dbname, name): + def __init__(self, instance, dbname, name, comment_map): self.name = name self.dbname = dbname self.summery = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}""" + self.json_summery_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}""" self.fields = [] self.fields_info = [] self.indexes = [] @@ -100,6 +152,10 @@ class MysqlTableSummary(TableSummary): self.indexes.append(index_summary) self.indexes_info.append(index_summary.get_summery()) + self.json_summery = self.json_summery_template.format( + name=name, comment=comment_map[name], fields=self.fields_info, indexes=self.indexes_info, size_in_bytes=1000, rows=1000 + ) + def get_summery(self): return self.summery.format( name=self.name, @@ -111,20 +167,24 @@ class MysqlTableSummary(TableSummary): def get_columns(self): return self.column_summery + def get_summary_json(self): + return self.json_summery + class MysqlFieldsSummary(FieldSummary): """Get mysql field summary template.""" def __init__(self, field): self.name = field[0] - self.summery = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """ + # self.summery = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """ + # self.summery = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}""" self.data_type = field[1] self.default_value = field[2] self.is_nullable = field[3] self.comment = field[4] def get_summery(self): - return self.summery.format( + return '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'.format( name=self.name, data_type=self.data_type, is_nullable=self.is_nullable, @@ -138,11 +198,12 @@ class MysqlIndexSummary(IndexSummary): def __init__(self, index): self.name = index[0] - self.summery = """index name:{name}, index bind columns:{bind_fields}""" + # self.summery = """index name:{name}, index bind columns:{bind_fields}""" + self.summery_template = '{{"name": "{name}", "columns": {bind_fields}}}' self.bind_fields = index[1] def get_summery(self): - return self.summery.format(name=self.name, bind_fields=self.bind_fields) + return self.summery_template.format(name=self.name, bind_fields=self.bind_fields) if __name__ == "__main__":