diff --git a/.env.template b/.env.template
index d49053e8c..2ed7932f3 100644
--- a/.env.template
+++ b/.env.template
@@ -102,4 +102,10 @@ LANGUAGE=en
 # ** PROXY_SERVER
 #*******************************************************************#
 PROXY_API_KEY=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
-PROXY_SERVER_URL=http://127.0.0.1:3000/api/openai/v1/chat/completions
\ No newline at end of file
+PROXY_SERVER_URL=http://127.0.0.1:3000/api/openai/v1/chat/completions
+
+
+#*******************************************************************#
+# ** SUMMARY_CONFIG
+#*******************************************************************#
+SUMMARY_CONFIG=VECTOR
\ No newline at end of file
diff --git a/pilot/common/sql_database.py b/pilot/common/sql_database.py
index 0c5fcb313..cc09e4328 100644
--- a/pilot/common/sql_database.py
+++ b/pilot/common/sql_database.py
@@ -19,6 +19,7 @@ from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
 from sqlalchemy.schema import CreateTable
 from sqlalchemy.orm import sessionmaker, scoped_session
 
+
 def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
     return (
         f'Name: {index["name"]}, Unique: {index["unique"]},'
@@ -91,7 +92,7 @@ class Database:
         # raise TypeError("sample_rows_in_table_info must be an integer")
         #
         # self._sample_rows_in_table_info = sample_rows_in_table_info
-        # self._indexes_in_table_info = indexes_in_table_info
+        self._indexes_in_table_info = indexes_in_table_info
         #
         # self._custom_table_info = custom_table_info
         # if self._custom_table_info:
@@ -429,3 +430,62 @@ class Database:
 
         return parsed, ttype, sql_type
 
+
+    def get_indexes(self, table_name):
+        """Get indexes of the specified table."""
+        session = self._db_sessions()
+        cursor = session.execute(text(f"SHOW INDEXES FROM {table_name}"))
+        indexes = cursor.fetchall()
+        return [(index[2], index[4]) for index in indexes]
+
+    def get_fields(self, table_name):
+        """Get column fields of the specified table."""
+        session = self._db_sessions()
+        cursor = session.execute(
+            text(
+                f"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.COLUMNS where table_name='{table_name}'"
+            )
+        )
+        fields = cursor.fetchall()
+        return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
+
+    def get_charset(self):
+        """Get the database character set."""
+        session = self._db_sessions()
+        cursor = session.execute(text("SELECT @@character_set_database"))
+        character_set = cursor.fetchone()[0]
+        return character_set
+
+    def get_collation(self):
+        """Get the database collation."""
+        session = self._db_sessions()
+        cursor = session.execute(text("SELECT @@collation_database"))
+        collation = cursor.fetchone()[0]
+        return collation
+
+    def get_grants(self):
+        """Get grant info."""
+        session = self._db_sessions()
+        cursor = session.execute(text("SHOW GRANTS"))
+        grants = cursor.fetchall()
+        return grants
+
+    def get_users(self):
+        """Get user info."""
+        session = self._db_sessions()
+        cursor = session.execute(text("SELECT user, host FROM mysql.user"))
+        users = cursor.fetchall()
+        return [(user[0], user[1]) for user in users]
+
+    def get_table_comments(self, database):
+        """Get table comments for all tables in the specified database."""
+        session = self._db_sessions()
+        cursor = session.execute(
+            text(
+                f"""SELECT table_name, table_comment FROM information_schema.tables WHERE table_schema = '{database}'"""
+            )
+        )
+        table_comments = cursor.fetchall()
+        return [
+            (table_comment[0], table_comment[1]) for table_comment in table_comments
+        ]
diff --git
a/pilot/configs/config.py b/pilot/configs/config.py index 8618651e4..c39f272f5 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -142,6 +142,11 @@ class Config(metaclass=Singleton): self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None) self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None) + ### EMBEDDING Configuration + self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec") + ### SUMMARY_CONFIG Configuration + self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "VECTOR") + def set_debug_mode(self, value: bool) -> None: """Set the debug mode value""" self.debug_mode = value diff --git a/pilot/conversation.py b/pilot/conversation.py index 3fe648529..f03c13e31 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -114,32 +114,65 @@ conv_default = Conversation( sep="###", ) +# +# conv_one_shot = Conversation( +# system="A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. " +# "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ", +# roles=("USER", "Assistant"), +# messages=( +# ( +# "USER", +# "What are the key differences between mysql and postgres?", +# ), +# ( +# "Assistant", +# "MySQL and PostgreSQL are both popular open-source relational database management systems (RDBMS) " +# "that have many similarities but also some differences. Here are some key differences: \n" +# "1. Data Types: PostgreSQL has a more extensive set of data types, " +# "including support for array, hstore, JSON, and XML, whereas MySQL has a more limited set.\n" +# "2. ACID compliance: Both MySQL and PostgreSQL support ACID compliance (Atomicity, Consistency, Isolation, Durability), " +# "but PostgreSQL is generally considered to be more strict in enforcing it.\n" +# "3. Replication: MySQL has a built-in replication feature, which allows you to replicate data across multiple servers," +# "whereas PostgreSQL has a similar feature, but it is not as mature as MySQL's.\n" +# "4. Performance: MySQL is generally considered to be faster and more efficient in handling large datasets, " +# "whereas PostgreSQL is known for its robustness and reliability.\n" +# "5. Licensing: MySQL is licensed under the GPL (General Public License), which means that it is free and open-source software, " +# "whereas PostgreSQL is licensed under the PostgreSQL License, which is also free and open-source but with different terms.\n" +# "Ultimately, the choice between MySQL and PostgreSQL depends on the specific needs and requirements of your application. " +# "Both are excellent database management systems, and choosing the right one " +# "for your project requires careful consideration of your application's requirements, performance needs, and scalability.", +# ), +# ), +# offset=2, +# sep_style=SeparatorStyle.SINGLE, +# sep="###", +# ) + + conv_one_shot = Conversation( - system="A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. " - "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ", - roles=("USER", "Assistant"), + system="You are a DB-GPT. Please provide me with user input and all table information known in the database, so I can accurately query tables are involved in the user input. If there are multiple tables involved, I will separate them by comma. 
Here is an example:", + roles=("USER", "ASSISTANT"), messages=( ( "USER", - "What are the key differences between mysql and postgres?", + "please query there are how many orders?" + "Querying the table involved in the user input?" + "database schema:" + "database name:db_test, database type:MYSQL, table infos:table name:carts,table description:购物车表;table name:categories,table description:商品分类表;table name:chat_groups,table description:群组表;table name:chat_users,table description:聊天用户表;table name:friends,table description:好友表;table name:messages,table description:消息表;table name:orders,table description:订单表;table name:products,table description:商品表;table name:table_test,table description:;table name:users,table description:用户表," + "You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads" + """Response Format: + { + "table": ["orders", "products"] + } + """, ), ( "Assistant", - "MySQL and PostgreSQL are both popular open-source relational database management systems (RDBMS) " - "that have many similarities but also some differences. Here are some key differences: \n" - "1. Data Types: PostgreSQL has a more extensive set of data types, " - "including support for array, hstore, JSON, and XML, whereas MySQL has a more limited set.\n" - "2. ACID compliance: Both MySQL and PostgreSQL support ACID compliance (Atomicity, Consistency, Isolation, Durability), " - "but PostgreSQL is generally considered to be more strict in enforcing it.\n" - "3. Replication: MySQL has a built-in replication feature, which allows you to replicate data across multiple servers," - "whereas PostgreSQL has a similar feature, but it is not as mature as MySQL's.\n" - "4. Performance: MySQL is generally considered to be faster and more efficient in handling large datasets, " - "whereas PostgreSQL is known for its robustness and reliability.\n" - "5. Licensing: MySQL is licensed under the GPL (General Public License), which means that it is free and open-source software, " - "whereas PostgreSQL is licensed under the PostgreSQL License, which is also free and open-source but with different terms.\n" - "Ultimately, the choice between MySQL and PostgreSQL depends on the specific needs and requirements of your application. " - "Both are excellent database management systems, and choosing the right one " - "for your project requires careful consideration of your application's requirements, performance needs, and scalability.", + """ + { + "table": ["orders", "products"] + } + """, ), ), offset=2, @@ -170,12 +203,12 @@ auto_dbgpt_one_shot = Conversation( 1. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember. 2. No user assistance 3. Exclusively use the commands listed in double quotes e.g. "command name" - - + + Schema: Database gpt-user Schema information as follows: users(city,create_time,email,last_login_time,phone,user_name); - - + + Commands: 1. analyze_code: Analyze Code, args: "code": "" 2. execute_python_file: Execute Python File, args: "filename": "" @@ -185,7 +218,7 @@ auto_dbgpt_one_shot = Conversation( 6. read_file: Read file, args: "filename": "" 7. write_to_file: Write to file, args: "filename": "", "text": "" 8. 
db_sql_executor: "Execute SQL in Database.", args: "sql": "" - + You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads Response Format: { @@ -248,6 +281,7 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回 {context} 问题: {question} + """ # conv_qa_prompt_template = """ Please provide the known information so that I can professionally and briefly answer the user's question. If the answer cannot be obtained from the provided content, @@ -285,4 +319,17 @@ conv_templates = { "conv_one_shot": conv_one_shot, "vicuna_v1": conv_vicuna_v1, "auto_dbgpt_one_shot": auto_dbgpt_one_shot, -} \ No newline at end of file +} + +conv_db_summary_templates = """ +Based on the following known database information?, answer which tables are involved in the user input. +Known database information:{db_profile_summary} +Input:{db_input} +You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads +The response format must be JSON, and the key of JSON must be "table". + +""" + +if __name__ == "__main__": + message = gen_sqlgen_conversation("dbgpt") + print(message) diff --git a/pilot/scene/base.py b/pilot/scene/base.py index 21f605fed..e301a14de 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -8,4 +8,5 @@ class ChatScene(Enum): ChatKnowledge = "chat_default_knowledge" ChatNewKnowledge = "chat_new_knowledge" ChatUrlKnowledge = "chat_url_knowledge" + InnerChatDBSummary = "inner_chat_db_summary" ChatNormal = "chat_normal" diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index 254fb33d2..0ef8bc701 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -37,11 +37,18 @@ class ChatWithDbAutoExecute(BaseChat): self.top_k: int = 5 def generate_input_values(self): + try: + from pilot.summary.db_summary_client import DBSummaryClient + except ImportError: + raise ValueError( + "Could not import DBSummaryClient. " + ) input_values = { "input": self.current_user_input, "top_k": str(self.top_k), "dialect": self.database.dialect, - "table_info": self.database.table_simple_info(self.db_connect) + # "table_info": self.database.table_simple_info(self.db_connect) + "table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) } return input_values diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index dcf10d782..74f83ddaa 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -37,8 +37,15 @@ class ChatWithDbQA(BaseChat): table_info = "" dialect = "mysql" + try: + from pilot.summary.db_summary_client import DBSummaryClient + except ImportError: + raise ValueError( + "Could not import DBSummaryClient. 
" + ) if self.db_name: - table_info = self.database.table_simple_info(self.db_connect) + table_info = DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) + # table_info = self.database.table_simple_info(self.db_connect) dialect = self.database.dialect input_values = { diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py index 7a346cbda..2e67df66c 100644 --- a/pilot/scene/chat_factory.py +++ b/pilot/scene/chat_factory.py @@ -9,6 +9,7 @@ from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge +from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary class ChatFactory(metaclass=Singleton): @staticmethod diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py index 7b9a11f85..f0db3df8f 100644 --- a/pilot/scene/chat_knowledge/custom/chat.py +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -18,7 +18,7 @@ from pilot.configs.model_config import ( VECTOR_SEARCH_TOP_K, ) -from pilot.scene.chat_normal.prompt import prompt +from pilot.scene.chat_knowledge.custom.prompt import prompt from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding CFG = Config() diff --git a/pilot/scene/chat_knowledge/custom/prompt.py b/pilot/scene/chat_knowledge/custom/prompt.py index 175deaddb..ab96c1703 100644 --- a/pilot/scene/chat_knowledge/custom/prompt.py +++ b/pilot/scene/chat_knowledge/custom/prompt.py @@ -11,6 +11,9 @@ from pilot.scene.chat_normal.out_parser import NormalChatOutputParser CFG = Config() +PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers""" + + _DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题, 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 已知内容: diff --git a/pilot/scene/chat_knowledge/inner_db_summary/__init__.py b/pilot/scene/chat_knowledge/inner_db_summary/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_knowledge/inner_db_summary/chat.py b/pilot/scene/chat_knowledge/inner_db_summary/chat.py new file mode 100644 index 000000000..cbdc44538 --- /dev/null +++ b/pilot/scene/chat_knowledge/inner_db_summary/chat.py @@ -0,0 +1,41 @@ + +from pilot.scene.base_chat import BaseChat +from pilot.scene.base import ChatScene +from pilot.configs.config import Config + +from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt + +CFG = Config() + + +class InnerChatDBSummary (BaseChat): + chat_scene: str = ChatScene.InnerChatDBSummary.value + + """Number of results to return from the query""" + + def __init__(self,temperature, max_new_tokens, chat_session_id, user_input, db_select, db_summary): + """ """ + super().__init__(temperature=temperature, + max_new_tokens=max_new_tokens, + chat_mode=ChatScene.InnerChatDBSummary, + chat_session_id=chat_session_id, + current_user_input=user_input) + self.db_name = db_select + self.db_summary = db_summary + + + def generate_input_values(self): + input_values = { + "db_input": self.db_name, + "db_profile_summary": self.db_summary + } + return input_values + + def do_with_prompt_response(self, prompt_response): + return prompt_response + + + + @property + def chat_type(self) -> str: + return ChatScene.InnerChatDBSummary.value diff --git 
a/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py b/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py
new file mode 100644
index 000000000..0d2a7e49d
--- /dev/null
+++ b/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py
@@ -0,0 +1,22 @@
+import json
+import re
+from abc import ABC, abstractmethod
+from typing import Dict, NamedTuple
+import pandas as pd
+from pilot.utils import build_logger
+from pilot.out_parser.base import BaseOutputParser, T
+from pilot.configs.model_config import LOGDIR
+
+
+logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
+
+
+class NormalChatOutputParser(BaseOutputParser):
+    def parse_prompt_response(self, model_out_text) -> T:
+        return model_out_text
+
+    def parse_view_response(self, ai_text) -> str:
+        return ai_text["table"]
+
+    def get_format_instructions(self) -> str:
+        pass
diff --git a/pilot/scene/chat_knowledge/inner_db_summary/prompt.py b/pilot/scene/chat_knowledge/inner_db_summary/prompt.py
new file mode 100644
index 000000000..739bf0364
--- /dev/null
+++ b/pilot/scene/chat_knowledge/inner_db_summary/prompt.py
@@ -0,0 +1,58 @@
+import builtins
+import importlib
+import json
+
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.common.schema import SeparatorStyle
+
+from pilot.scene.chat_knowledge.inner_db_summary.out_parser import NormalChatOutputParser
+
+
+CFG = Config()
+
+PROMPT_SCENE_DEFINE = """"""
+
+_DEFAULT_TEMPLATE = """
+Based on the following known database information, answer which tables are involved in the user input.
+Known database information:{db_profile_summary}
+Input:{db_input}
+You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads
+The response format must be JSON, and the key of the JSON must be "table".
+ +""" +PROMPT_RESPONSE = """You must respond in JSON format as following format: +{response} + +Ensure the response is correct json and can be parsed by Python json.loads +""" + + + +RESPONSE_FORMAT = { + "table": ["orders", "products"] + } + + + +PROMPT_SEP = SeparatorStyle.SINGLE.value + +PROMPT_NEED_NEED_STREAM_OUT = False + +prompt = PromptTemplate( + template_scene=ChatScene.InnerChatDBSummary.value, + input_variables=["db_profile_summary", "db_input", "response"], + response_format=json.dumps(RESPONSE_FORMAT, indent=4), + template_define=PROMPT_SCENE_DEFINE, + template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE, + stream_out=PROMPT_NEED_NEED_STREAM_OUT, + output_parser=NormalChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT + ), +) + + +CFG.prompt_templates.update({prompt.template_scene: prompt}) + + diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py index 0c54f6001..0666de9e1 100644 --- a/pilot/scene/chat_knowledge/url/chat.py +++ b/pilot/scene/chat_knowledge/url/chat.py @@ -18,7 +18,7 @@ from pilot.configs.model_config import ( VECTOR_SEARCH_TOP_K, ) -from pilot.scene.chat_normal.prompt import prompt +from pilot.scene.chat_knowledge.url.prompt import prompt from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding CFG = Config() @@ -44,6 +44,7 @@ class ChatUrlKnowledge (BaseChat): } self.knowledge_embedding_client = KnowledgeEmbedding( file_path=url, + file_type="url", model_name=LLM_MODEL_CONFIG["text2vec"], local_persist=False, vector_store_config=vector_store_config, diff --git a/pilot/server/vectordb_qa.py b/pilot/server/vectordb_qa.py index e64f68097..ff794322f 100644 --- a/pilot/server/vectordb_qa.py +++ b/pilot/server/vectordb_qa.py @@ -4,7 +4,7 @@ from langchain.prompts import PromptTemplate from pilot.configs.model_config import VECTOR_SEARCH_TOP_K -from pilot.conversation import conv_qa_prompt_template +from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates from pilot.logs import logger from pilot.model.vicuna_llm import VicunaLLM from pilot.vector_store.file_loader import KnownLedge2Vector @@ -53,3 +53,17 @@ class KnownLedgeBaseQA: print("new prompt length:" + str(len(prompt))) return prompt + + @staticmethod + def build_db_summary_prompt(query, db_profile_summary, state): + prompt_template = PromptTemplate( + template=conv_db_summary_templates, + input_variables=["db_input", "db_profile_summary"], + ) + # context = [d.page_content for d in docs] + result = prompt_template.format( + db_profile_summary=db_profile_summary, db_input=query + ) + state.messages[-2][1] = result + prompt = state.get_prompt() + return prompt diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 240dba201..132944fb6 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -3,16 +3,14 @@ import traceback import argparse import datetime -import json import os import shutil import sys -import time import uuid import gradio as gr -import requests +from pilot.summary.db_summary_client import DBSummaryClient ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) @@ -27,13 +25,9 @@ from pilot.configs.model_config import ( KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, LOGDIR, - VECTOR_SEARCH_TOP_K, ) from pilot.conversation import ( - SeparatorStyle, - conv_qa_prompt_template, - conv_templates, conversation_sql_mode, conversation_types, chat_mode_title, @@ -41,19 +35,15 @@ from pilot.conversation import 
( ) from pilot.common.plugins import scan_plugins -from pilot.prompts.generator import PluginPromptGenerator from pilot.server.gradio_css import code_highlight_css from pilot.server.gradio_patch import Chatbot as grChatbot -from pilot.server.vectordb_qa import KnownLedgeBaseQA from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding -from pilot.utils import build_logger, server_error_msg +from pilot.utils import build_logger from pilot.vector_store.extract_tovec import ( get_vector_storelist, knownledge_tovec_st, - load_knownledge_from_doc, ) -from pilot.commands.command import execute_ai_response_json from pilot.scene.base import ChatScene from pilot.scene.chat_factory import ChatFactory from pilot.language.translation_handler import get_lang_text @@ -75,6 +65,7 @@ vs_list = [get_lang_text("create_knowledge_base")] + get_vector_storelist() autogpt = False vector_store_client = None vector_store_name = {"vs_name": ""} +# db_summary = {"dbsummary": ""} priority = {"vicuna-13b": "aaa"} @@ -416,6 +407,8 @@ def build_single_model_ui(): show_label=True, ).style(container=False) + db_selector.change(fn=db_selector_changed, inputs=db_selector) + sql_mode = gr.Radio( [ get_lang_text("sql_generate_mode_direct"), @@ -609,6 +602,10 @@ def save_vs_name(vs_name): return vs_name +def db_selector_changed(dbname): + DBSummaryClient.db_summary_embedding(dbname) + + def knowledge_embedding_store(vs_id, files): # vs_path = os.path.join(VS_ROOT_PATH, vs_id) if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)): diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index 8f411657d..aefddd848 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -18,12 +18,12 @@ CFG = Config() class KnowledgeEmbedding: - def __init__(self, file_path, model_name, vector_store_config, local_persist=True): + def __init__(self, file_path, model_name, vector_store_config, local_persist=True, file_type="default"): """Initialize with Loader url, model_name, vector_store_config""" self.file_path = file_path self.model_name = model_name self.vector_store_config = vector_store_config - self.file_type = "default" + self.file_type = file_type self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) self.vector_store_config["embeddings"] = self.embeddings self.local_persist = local_persist @@ -37,7 +37,13 @@ class KnowledgeEmbedding: self.knowledge_embedding_client.batch_embedding() def init_knowledge_embedding(self): - if self.file_path.endswith(".pdf"): + if self.file_type == "url": + embedding = URLEmbedding( + file_path=self.file_path, + model_name=self.model_name, + vector_store_config=self.vector_store_config, + ) + elif self.file_path.endswith(".pdf"): embedding = PDFEmbedding( file_path=self.file_path, model_name=self.model_name, @@ -56,18 +62,15 @@ class KnowledgeEmbedding: model_name=self.model_name, vector_store_config=self.vector_store_config, ) + + elif self.file_type == "default": embedding = MarkdownEmbedding( file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config, ) - elif self.file_type == "url": - embedding = URLEmbedding( - file_path=self.file_path, - model_name=self.model_name, - vector_store_config=self.vector_store_config, - ) + return embedding diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py index c76cf65d2..de1767c51 100644 --- a/pilot/source_embedding/pdf_embedding.py +++ 
b/pilot/source_embedding/pdf_embedding.py @@ -11,7 +11,7 @@ from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter class PDFEmbedding(SourceEmbedding): - """yuque embedding for read yuque document.""" + """pdf embedding for read pdf document.""" def __init__(self, file_path, model_name, vector_store_config): """Initialize with pdf path.""" diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py index 94e48e79e..7db92ea9b 100644 --- a/pilot/source_embedding/source_embedding.py +++ b/pilot/source_embedding/source_embedding.py @@ -66,6 +66,9 @@ class SourceEmbedding(ABC): """vector store similarity_search""" return self.vector_client.similar_search(doc, topk) + def vector_name_exist(self): + return self.vector_client.vector_name_exists() + def source_embedding(self): if "read" in registered_methods: text = self.read() diff --git a/pilot/source_embedding/string_embedding.py b/pilot/source_embedding/string_embedding.py new file mode 100644 index 000000000..b4d7b1228 --- /dev/null +++ b/pilot/source_embedding/string_embedding.py @@ -0,0 +1,30 @@ +from typing import List + +from langchain.schema import Document + +from pilot import SourceEmbedding, register + + +class StringEmbedding(SourceEmbedding): + """string embedding for read string document.""" + + def __init__(self, file_path, model_name, vector_store_config): + """Initialize with pdf path.""" + super().__init__(file_path, model_name, vector_store_config) + self.file_path = file_path + self.model_name = model_name + self.vector_store_config = vector_store_config + + @register + def read(self): + """Load from String path.""" + metadata = {"source": "db_summary"} + return [Document(page_content=self.file_path, metadata=metadata)] + + @register + def data_process(self, documents: List[Document]): + i = 0 + for d in documents: + documents[i].page_content = d.page_content.replace("\n", "") + i += 1 + return documents diff --git a/pilot/source_embedding/url_embedding.py b/pilot/source_embedding/url_embedding.py index 59eef19e7..e74defa80 100644 --- a/pilot/source_embedding/url_embedding.py +++ b/pilot/source_embedding/url_embedding.py @@ -13,6 +13,7 @@ class URLEmbedding(SourceEmbedding): def __init__(self, file_path, model_name, vector_store_config): """Initialize with url path.""" + super().__init__(file_path, model_name, vector_store_config) self.file_path = file_path self.model_name = model_name self.vector_store_config = vector_store_config diff --git a/pilot/summary/__init__.py b/pilot/summary/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/summary/db_summary.py b/pilot/summary/db_summary.py new file mode 100644 index 000000000..30f5e2e48 --- /dev/null +++ b/pilot/summary/db_summary.py @@ -0,0 +1,31 @@ +class DBSummary: + def __init__(self, name): + self.name = name + self.summery = None + self.tables = [] + self.metadata = str + + def get_summery(self): + return self.summery + + +class TableSummary: + def __init__(self, name): + self.name = name + self.summery = None + self.fields = [] + self.indexes = [] + + +class FieldSummary: + def __init__(self, name): + self.name = name + self.summery = None + self.data_type = None + + +class IndexSummary: + def __init__(self, name): + self.name = name + self.summery = None + self.bind_fields = [] diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py new file mode 100644 index 000000000..a34b87a93 --- /dev/null +++ b/pilot/summary/db_summary_client.py @@ -0,0 +1,176 @@ 
+import json +import uuid + +from langchain.embeddings import HuggingFaceEmbeddings, logger + +from pilot.configs.config import Config +from pilot.configs.model_config import LLM_MODEL_CONFIG +from pilot.scene.base import ChatScene +from pilot.scene.base_chat import BaseChat +from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding +from pilot.source_embedding.string_embedding import StringEmbedding +from pilot.summary.mysql_db_summary import MysqlSummary +from pilot.scene.chat_factory import ChatFactory + +CFG = Config() + + +class DBSummaryClient: + """db summary client, provide db_summary_embedding(put db profile and table profile summary into vector store) + , get_similar_tables method(get user query related tables info) + """ + + @staticmethod + def db_summary_embedding(dbname): + """put db profile and table profile summary into vector store""" + if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None: + db_summary_client = MysqlSummary(dbname) + embeddings = HuggingFaceEmbeddings( + model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL] + ) + vector_store_config = { + "vector_store_name": dbname + "_profile", + "embeddings": embeddings, + } + embedding = StringEmbedding( + db_summary_client.get_summery(), + LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], + vector_store_config, + ) + if not embedding.vector_name_exist(): + if CFG.SUMMARY_CONFIG == "FAST": + for vector_table_info in db_summary_client.get_summery(): + embedding = StringEmbedding( + vector_table_info, + LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], + vector_store_config, + ) + embedding.source_embedding() + else: + embedding = StringEmbedding( + db_summary_client.get_summery(), + LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], + vector_store_config, + ) + embedding.source_embedding() + for ( + table_name, + table_summary, + ) in db_summary_client.get_table_summary().items(): + table_vector_store_config = { + "vector_store_name": table_name + "_ts", + "embeddings": embeddings, + } + embedding = StringEmbedding( + table_summary, + LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], + table_vector_store_config, + ) + embedding.source_embedding() + + logger.info("db summary embedding success") + + @staticmethod + def get_similar_tables(dbname, query, topk): + """get user query related tables info""" + embeddings = HuggingFaceEmbeddings( + model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL] + ) + vector_store_config = { + "vector_store_name": dbname + "_profile", + "embeddings": embeddings, + } + knowledge_embedding_client = KnowledgeEmbedding( + file_path="", + model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], + local_persist=False, + vector_store_config=vector_store_config, + ) + if CFG.SUMMARY_CONFIG == "FAST": + table_docs = knowledge_embedding_client.similar_search(query, topk) + related_tables = [json.loads(table_doc.page_content)["table_name"] for table_doc in table_docs] + else: + table_docs = knowledge_embedding_client.similar_search(query, 1) + # prompt = KnownLedgeBaseQA.build_db_summary_prompt( + # query, table_docs[0].page_content + # ) + related_tables = _get_llm_response(query, dbname, table_docs[0].page_content) + related_table_summaries = [] + for table in related_tables: + vector_store_config = { + "vector_store_name": table + "_ts", + "embeddings": embeddings, + } + knowledge_embedding_client = KnowledgeEmbedding( + file_path="", + model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], + local_persist=False, + vector_store_config=vector_store_config, + ) + table_summery = knowledge_embedding_client.similar_search(query, 1) + 
            related_table_summaries.append(table_summery[0].page_content)
+        return related_table_summaries
+
+
+def _get_llm_response(query, db_input, dbsummary):
+    chat_param = {
+        "temperature": 0.7,
+        "max_new_tokens": 512,
+        "chat_session_id": uuid.uuid1(),
+        "user_input": query,
+        "db_input": db_input,
+        "db_summary": dbsummary,
+    }
+    chat_factory = ChatFactory()
+    chat: BaseChat = chat_factory.get_implementation(ChatScene.InnerChatDBSummary.value, **chat_param)
+
+    return chat.call()
+    # payload = {
+    #     "model": CFG.LLM_MODEL,
+    #     "prompt": prompt,
+    #     "temperature": float(0.7),
+    #     "max_new_tokens": int(512),
+    #     "stop": state.sep
+    #     if state.sep_style == SeparatorStyle.SINGLE
+    #     else state.sep2,
+    # }
+    # headers = {"User-Agent": "dbgpt Client"}
+    # response = requests.post(
+    #     urljoin(CFG.MODEL_SERVER, "generate"),
+    #     headers=headers,
+    #     json=payload,
+    #     timeout=120,
+    # )
+    #
+    # print(related_tables)
+    # return related_tables
+    # except NotCommands as e:
+    #     print("llm response error:" + e.message)
+
+
+
+# if __name__ == "__main__":
+#     # summary = DBSummaryClient.get_similar_tables("db_test", "查询在线用户的购物车", 10)
+#
+#     text = """Based on the input "查询在线聊天的用户好友" and the known database information, the tables involved in the user input are "chat_users" and "friends".
+# Response:
+#
+# {
+#     "table": ["chat_users"]
+# }"""
+#     text = text.rstrip().replace("\n", "")
+#     start = text.find("{")
+#     end = text.find("}") + 1
+#
+#     # Extract the JSON substring from the model output
+#     json_str = text[start:end]
+#
+#     # Parse the JSON into a Python dict
+#     data = json.loads(json_str)
+#     # pattern = r'{s*"table"s*:s*[[^]]*]s*}'
+#     # match = re.search(pattern, text)
+#     # if match:
+#     #     json_string = match.group(0)
+#     #     # Parse the JSON string into a Python object
+#     #     json_obj = json.loads(json_string)
+#     # print(summary)
diff --git a/pilot/summary/mysql_db_summary.py b/pilot/summary/mysql_db_summary.py
new file mode 100644
index 000000000..e14aad9a3
--- /dev/null
+++ b/pilot/summary/mysql_db_summary.py
@@ -0,0 +1,134 @@
+import json
+
+from pilot.configs.config import Config
+from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, IndexSummary
+
+CFG = Config()
+
+
+class MysqlSummary(DBSummary):
+    """Get mysql summary template."""
+
+    def __init__(self, name):
+        self.name = name
+        self.type = "MYSQL"
+        self.summery = (
+            """database name:{name}, database type:{type}, table infos:{table_info}"""
+        )
+        self.tables = {}
+        self.tables_info = []
+        self.vector_tables_info = []
+        # self.tables_summary = {}
+
+        self.db = CFG.local_db
+        self.db.get_session(name)
+        self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
+            users=self.db.get_users(),
+            grant=self.db.get_grants(),
+            charset=self.db.get_charset(),
+            collation=self.db.get_collation(),
+        )
+        tables = self.db.get_table_names()
+        self.table_comments = self.db.get_table_comments(name)
+        for table_comment in self.table_comments:
+            self.tables_info.append(
+                "table name:{table_name},table description:{table_comment}".format(
+                    table_name=table_comment[0], table_comment=table_comment[1]
+                )
+            )
+            vector_table = json.dumps(
+                {"table_name": table_comment[0], "table_description": table_comment[1]}
+            )
+            self.vector_tables_info.append(
+                vector_table.encode("utf-8").decode("unicode_escape")
+            )
+
+        for table_name in tables:
+            table_summary = MysqlTableSummary(self.db, name, table_name)
+            self.tables[table_name] = table_summary.get_summery()
+            # self.tables_info.append(table_summary.get_summery())
+
+    def get_summery(self):
+        if CFG.SUMMARY_CONFIG ==
"VECTOR": + return self.vector_tables_info + else: + return self.summery.format( + name=self.name, type=self.type, table_info=";".join(self.tables_info) + ) + + def get_table_summary(self): + return self.tables + + def get_table_comments(self): + return self.table_comments + + +class MysqlTableSummary(TableSummary): + """Get mysql table summary template.""" + + def __init__(self, instance, dbname, name): + self.name = name + self.dbname = dbname + self.summery = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}""" + self.fields = [] + self.fields_info = [] + self.indexes = [] + self.indexes_info = [] + self.db = instance + fields = self.db.get_fields(name) + indexes = self.db.get_indexes(name) + for field in fields: + field_summary = MysqlFieldsSummary(field) + self.fields.append(field_summary) + self.fields_info.append(field_summary.get_summery()) + + for index in indexes: + index_summary = MysqlIndexSummary(index) + self.indexes.append(index_summary) + self.indexes_info.append(index_summary.get_summery()) + + def get_summery(self): + return self.summery.format( + name=self.name, + dbname=self.dbname, + fields=";".join(self.fields_info), + indexes=";".join(self.indexes_info), + ) + + +class MysqlFieldsSummary(FieldSummary): + """Get mysql field summary template.""" + + def __init__(self, field): + self.name = field[0] + self.summery = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """ + self.data_type = field[1] + self.default_value = field[2] + self.is_nullable = field[3] + self.comment = field[4] + + def get_summery(self): + return self.summery.format( + name=self.name, + data_type=self.data_type, + is_nullable=self.is_nullable, + default_value=self.default_value, + comment=self.comment, + ) + + +class MysqlIndexSummary(IndexSummary): + """Get mysql index summary template.""" + + def __init__(self, index): + self.name = index[0] + self.summery = """index name:{name}, index bind columns:{bind_fields}""" + self.bind_fields = index[1] + + def get_summery(self): + return self.summery.format(name=self.name, bind_fields=self.bind_fields) + + +if __name__ == "__main__": + summary = MysqlSummary("db_test") + print(summary.get_summery()) diff --git a/pilot/vector_store/chroma_store.py b/pilot/vector_store/chroma_store.py index 1ec9e8b04..3a9de6874 100644 --- a/pilot/vector_store/chroma_store.py +++ b/pilot/vector_store/chroma_store.py @@ -24,6 +24,11 @@ class ChromaStore(VectorStoreBase): logger.info("ChromaStore similar search") return self.vector_store_client.similarity_search(text, topk) + def vector_name_exists(self): + return ( + os.path.exists(self.persist_dir) and len(os.listdir(self.persist_dir)) > 0 + ) + def load_document(self, documents): logger.info("ChromaStore load document") texts = [doc.page_content for doc in documents] diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py index 3ff473f1e..6c7028856 100644 --- a/pilot/vector_store/connector.py +++ b/pilot/vector_store/connector.py @@ -5,15 +5,22 @@ connector = {"Chroma": ChromaStore, "Milvus": None} class VectorStoreConnector: - """vector store connector, can connect different vector db provided load document api and similar search api""" + """vector store connector, can connect different vector db provided load document api and similar search api.""" def __init__(self, vector_store_type, ctx: {}) -> None: + """initialize vector store connector.""" self.ctx = ctx 
         self.connector_class = connector[vector_store_type]
         self.client = self.connector_class(ctx)
 
     def load_document(self, docs):
+        """Load documents into the vector database."""
         self.client.load_document(docs)
 
     def similar_search(self, docs, topk):
+        """Similarity search in the vector database."""
         return self.client.similar_search(docs, topk)
+
+    def vector_name_exists(self):
+        """Whether the vector store name already exists."""
+        return self.client.vector_name_exists()
diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py
index c9fc985e4..4535ea30a 100644
--- a/pilot/vector_store/milvus_store.py
+++ b/pilot/vector_store/milvus_store.py
@@ -319,5 +319,9 @@ class MilvusStore(VectorStoreBase):
 
         return data[0], ret
 
+    def vector_name_exists(self):
+        """Whether the Milvus collection for this vector store name exists."""
+        return utility.has_collection(self.collection_name)
+
     def close(self):
         connections.disconnect()
diff --git a/pilot/vector_store/vector_store_base.py b/pilot/vector_store/vector_store_base.py
index 70888f5aa..0108e06b1 100644
--- a/pilot/vector_store/vector_store_base.py
+++ b/pilot/vector_store/vector_store_base.py
@@ -11,5 +11,10 @@ class VectorStoreBase(ABC):
 
     @abstractmethod
     def similar_search(self, text, topk) -> None:
-        """Initialize schema in vector database."""
+        """Similarity search in the vector database."""
+        pass
+
+    @abstractmethod
+    def vector_name_exists(self) -> bool:
+        """Whether the vector store name already exists."""
         pass
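
Note (not part of the patch): a minimal usage sketch of the summary flow this diff introduces, using only the two methods it adds to DBSummaryClient. The database name and question are illustrative, and the snippet assumes a reachable local MySQL plus the embedding model and SUMMARY_CONFIG configured via .env as shown above.

from pilot.summary.db_summary_client import DBSummaryClient

dbname = "db_test"                       # hypothetical database name
question = "how many orders are there?"  # hypothetical user input

# 1. Build and persist the db/table profile embeddings; the webserver triggers
#    this from db_selector_changed() when a database is selected.
DBSummaryClient.db_summary_embedding(dbname)

# 2. At question time, fetch the table summaries most similar to the input;
#    ChatWithDbAutoExecute / ChatWithDbQA pass this result to the prompt as "table_info".
table_info = DBSummaryClient.get_similar_tables(dbname=dbname, query=question, topk=5)
print(table_info)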