Merge branch 'summary' into dev

# Conflicts:
#	pilot/common/sql_database.py
#	pilot/server/webserver.py
yhjun1026 2023-06-01 14:13:04 +08:00
commit 1c75dda0a0
30 changed files with 728 additions and 56 deletions

View File

@ -102,4 +102,10 @@ LANGUAGE=en
# ** PROXY_SERVER
#*******************************************************************#
PROXY_API_KEY=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
PROXY_SERVER_URL=http://127.0.0.1:3000/api/openai/v1/chat/completions
PROXY_SERVER_URL=http://127.0.0.1:3000/api/openai/v1/chat/completions
#*******************************************************************#
# ** SUMMARY_CONFIG
#*******************************************************************#
SUMMARY_CONFIG=VECTOR

View File

@ -19,6 +19,7 @@ from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
return (
f'Name: {index["name"]}, Unique: {index["unique"]},'
@ -91,7 +92,7 @@ class Database:
# raise TypeError("sample_rows_in_table_info must be an integer")
#
# self._sample_rows_in_table_info = sample_rows_in_table_info
# self._indexes_in_table_info = indexes_in_table_info
self._indexes_in_table_info = indexes_in_table_info
#
# self._custom_table_info = custom_table_info
# if self._custom_table_info:
@ -429,3 +430,65 @@ class Database:
return parsed, ttype, sql_type
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW INDEXES FROM {table_name}"))
indexes = cursor.fetchall()
return [(index[2], index[4]) for index in indexes]
def get_fields(self, table_name):
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
text(
f"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.COLUMNS where table_name='{table_name}'".format(
table_name
)
)
)
fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
def get_charset(self):
"""Get character_set."""
session = self._db_sessions()
cursor = session.execute(text(f"SELECT @@character_set_database"))
character_set = cursor.fetchone()[0]
return character_set
def get_collation(self):
"""Get collation."""
session = self._db_sessions()
cursor = session.execute(text(f"SELECT @@collation_database"))
collation = cursor.fetchone()[0]
return collation
def get_grants(self):
"""Get grant info."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW GRANTS"))
grants = cursor.fetchall()
return grants
def get_users(self):
"""Get user info."""
session = self._db_sessions()
cursor = session.execute(text(f"SELECT user, host FROM mysql.user"))
users = cursor.fetchall()
return [(user[0], user[1]) for user in users]
def get_table_comments(self, database):
session = self._db_sessions()
cursor = session.execute(
text(
f"""SELECT table_name, table_comment FROM information_schema.tables WHERE table_schema = '{database}'""".format(
database
)
)
)
table_comments = cursor.fetchall()
return [
(table_comment[0], table_comment[1]) for table_comment in table_comments
]
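The new introspection helpers above all follow the same pattern: open a session via `_db_sessions()`, run a raw MySQL statement, and return plain tuples. A minimal usage sketch, assuming `CFG.local_db` is an initialized `Database` instance (as `mysql_db_summary.py` below assumes) and that a schema named `db_test` with an `orders` table exists:

```python
from pilot.configs.config import Config

CFG = Config()

db = CFG.local_db              # assumed to be a connected Database instance
db.get_session("db_test")      # select the schema to introspect (hypothetical name)

print(db.get_charset())                    # e.g. "utf8mb4"
print(db.get_collation())                  # e.g. "utf8mb4_general_ci"
print(db.get_table_comments("db_test"))    # [(table_name, table_comment), ...]
print(db.get_fields("orders"))             # [(name, type, default, nullable, comment), ...]
print(db.get_indexes("orders"))            # [(index_name, column_name), ...]
```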

View File

@ -142,6 +142,11 @@ class Config(metaclass=Singleton):
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
### EMBEDDING Configuration
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
### SUMMARY_CONFIG Configuration
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "VECTOR")
def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value"""
self.debug_mode = value
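`SUMMARY_CONFIG` mirrors the new `.env` entry at the top of this change and is read by `DBSummaryClient` and `MysqlSummary` below; the only values that appear in this commit are `VECTOR` (the default) and `FAST`. A rough sketch of how the new code branches on it:

```python
from pilot.configs.config import Config

CFG = Config()

# Values observed in this commit: "VECTOR" (default) and "FAST".
if CFG.SUMMARY_CONFIG == "FAST":
    # per-table JSON summaries are embedded individually and tables are picked
    # purely by vector similarity search
    pass
else:
    # the database profile is embedded as one document and an inner LLM chat
    # (InnerChatDBSummary) decides which tables are relevant
    pass
```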

View File

@ -114,32 +114,65 @@ conv_default = Conversation(
sep="###",
)
#
# conv_one_shot = Conversation(
# system="A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. "
# "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
# roles=("USER", "Assistant"),
# messages=(
# (
# "USER",
# "What are the key differences between mysql and postgres?",
# ),
# (
# "Assistant",
# "MySQL and PostgreSQL are both popular open-source relational database management systems (RDBMS) "
# "that have many similarities but also some differences. Here are some key differences: \n"
# "1. Data Types: PostgreSQL has a more extensive set of data types, "
# "including support for array, hstore, JSON, and XML, whereas MySQL has a more limited set.\n"
# "2. ACID compliance: Both MySQL and PostgreSQL support ACID compliance (Atomicity, Consistency, Isolation, Durability), "
# "but PostgreSQL is generally considered to be more strict in enforcing it.\n"
# "3. Replication: MySQL has a built-in replication feature, which allows you to replicate data across multiple servers,"
# "whereas PostgreSQL has a similar feature, but it is not as mature as MySQL's.\n"
# "4. Performance: MySQL is generally considered to be faster and more efficient in handling large datasets, "
# "whereas PostgreSQL is known for its robustness and reliability.\n"
# "5. Licensing: MySQL is licensed under the GPL (General Public License), which means that it is free and open-source software, "
# "whereas PostgreSQL is licensed under the PostgreSQL License, which is also free and open-source but with different terms.\n"
# "Ultimately, the choice between MySQL and PostgreSQL depends on the specific needs and requirements of your application. "
# "Both are excellent database management systems, and choosing the right one "
# "for your project requires careful consideration of your application's requirements, performance needs, and scalability.",
# ),
# ),
# offset=2,
# sep_style=SeparatorStyle.SINGLE,
# sep="###",
# )
conv_one_shot = Conversation(
system="A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. "
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
roles=("USER", "Assistant"),
system="You are a DB-GPT. Please provide me with user input and all table information known in the database, so I can accurately query tables are involved in the user input. If there are multiple tables involved, I will separate them by comma. Here is an example:",
roles=("USER", "ASSISTANT"),
messages=(
(
"USER",
"What are the key differences between mysql and postgres?",
"please query there are how many orders?"
"Querying the table involved in the user input"
"database schema:"
"database name:db_test, database type:MYSQL, table infos:table name:carts,table description:购物车表;table name:categories,table description:商品分类表;table name:chat_groups,table description:群组表;table name:chat_users,table description:聊天用户表;table name:friends,table description:好友表;table name:messages,table description:消息表;table name:orders,table description:订单表;table name:products,table description:商品表;table name:table_test,table description:;table name:users,table description:用户表,"
"You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads"
"""Response Format:
{
"table": ["orders", "products"]
}
""",
),
(
"Assistant",
"MySQL and PostgreSQL are both popular open-source relational database management systems (RDBMS) "
"that have many similarities but also some differences. Here are some key differences: \n"
"1. Data Types: PostgreSQL has a more extensive set of data types, "
"including support for array, hstore, JSON, and XML, whereas MySQL has a more limited set.\n"
"2. ACID compliance: Both MySQL and PostgreSQL support ACID compliance (Atomicity, Consistency, Isolation, Durability), "
"but PostgreSQL is generally considered to be more strict in enforcing it.\n"
"3. Replication: MySQL has a built-in replication feature, which allows you to replicate data across multiple servers,"
"whereas PostgreSQL has a similar feature, but it is not as mature as MySQL's.\n"
"4. Performance: MySQL is generally considered to be faster and more efficient in handling large datasets, "
"whereas PostgreSQL is known for its robustness and reliability.\n"
"5. Licensing: MySQL is licensed under the GPL (General Public License), which means that it is free and open-source software, "
"whereas PostgreSQL is licensed under the PostgreSQL License, which is also free and open-source but with different terms.\n"
"Ultimately, the choice between MySQL and PostgreSQL depends on the specific needs and requirements of your application. "
"Both are excellent database management systems, and choosing the right one "
"for your project requires careful consideration of your application's requirements, performance needs, and scalability.",
"""
{
"table": ["orders", "products"]
}
""",
),
),
offset=2,
@ -170,12 +203,12 @@ auto_dbgpt_one_shot = Conversation(
1. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
2. No user assistance
3. Exclusively use the commands listed in double quotes e.g. "command name"
Schema:
Database gpt-user Schema information as follows: users(city,create_time,email,last_login_time,phone,user_name);
Commands:
1. analyze_code: Analyze Code, args: "code": "<full_code_string>"
2. execute_python_file: Execute Python File, args: "filename": "<filename>"
@ -185,7 +218,7 @@ auto_dbgpt_one_shot = Conversation(
6. read_file: Read file, args: "filename": "<filename>"
7. write_to_file: Write to file, args: "filename": "<filename>", "text": "<text>"
8. db_sql_executor: "Execute SQL in Database.", args: "sql": "<sql>"
You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads
Response Format:
{
@ -248,6 +281,7 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回
{context}
问题:
{question}
"""
# conv_qa_prompt_template = """ Please provide the known information so that I can professionally and briefly answer the user's question. If the answer cannot be obtained from the provided content,
@ -285,4 +319,17 @@ conv_templates = {
"conv_one_shot": conv_one_shot,
"vicuna_v1": conv_vicuna_v1,
"auto_dbgpt_one_shot": auto_dbgpt_one_shot,
}
}
conv_db_summary_templates = """
Based on the following known database information, answer which tables are involved in the user input.
Known database information:{db_profile_summary}
Input:{db_input}
You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads
The response format must be JSON, and the key of JSON must be "table".
"""
if __name__ == "__main__":
message = gen_sqlgen_conversation("dbgpt")
print(message)

View File

@ -8,4 +8,5 @@ class ChatScene(Enum):
ChatKnowledge = "chat_default_knowledge"
ChatNewKnowledge = "chat_new_knowledge"
ChatUrlKnowledge = "chat_url_knowledge"
InnerChatDBSummary = "inner_chat_db_summary"
ChatNormal = "chat_normal"

View File

@ -37,11 +37,18 @@ class ChatWithDbAutoExecute(BaseChat):
self.top_k: int = 5
def generate_input_values(self):
try:
from pilot.summary.db_summary_client import DBSummaryClient
except ImportError:
raise ValueError(
"Could not import DBSummaryClient. "
)
input_values = {
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect)
# "table_info": self.database.table_simple_info(self.db_connect)
"table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
}
return input_values

View File

@ -37,8 +37,15 @@ class ChatWithDbQA(BaseChat):
table_info = ""
dialect = "mysql"
try:
from pilot.summary.db_summary_client import DBSummaryClient
except ImportError:
raise ValueError(
"Could not import DBSummaryClient. "
)
if self.db_name:
table_info = self.database.table_simple_info(self.db_connect)
table_info = DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
# table_info = self.database.table_simple_info(self.db_connect)
dialect = self.database.dialect
input_values = {

View File

@ -9,6 +9,7 @@ from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
class ChatFactory(metaclass=Singleton):
@staticmethod

View File

@ -18,7 +18,7 @@ from pilot.configs.model_config import (
VECTOR_SEARCH_TOP_K,
)
from pilot.scene.chat_normal.prompt import prompt
from pilot.scene.chat_knowledge.custom.prompt import prompt
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
CFG = Config()

View File

@ -11,6 +11,9 @@ from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
CFG = Config()
PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造
已知内容:

View File

@ -0,0 +1,41 @@
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt
CFG = Config()
class InnerChatDBSummary(BaseChat):
chat_scene: str = ChatScene.InnerChatDBSummary.value
"""Number of results to return from the query"""
def __init__(self, temperature, max_new_tokens, chat_session_id, user_input, db_select, db_summary):
""" """
super().__init__(temperature=temperature,
max_new_tokens=max_new_tokens,
chat_mode=ChatScene.InnerChatDBSummary,
chat_session_id=chat_session_id,
current_user_input=user_input)
self.db_name = db_select
self.db_summary = db_summary
def generate_input_values(self):
input_values = {
"db_input": self.db_name,
"db_profile_summary": self.db_summary
}
return input_values
def do_with_prompt_response(self, prompt_response):
return prompt_response
@property
def chat_type(self) -> str:
return ChatScene.InnerChatDBSummary.value
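This inner chat is never exposed in the UI; `DBSummaryClient._get_llm_response` further down drives it through `ChatFactory`. A minimal sketch of that call path, assuming `get_implementation` forwards keyword arguments to the chat constructor and that `call()` returns the parsed model output:

```python
import uuid

from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory

# Hypothetical db profile text, in the shape MysqlSummary.get_summery() produces.
profile_text = (
    "database name:db_test, database type:MYSQL, "
    "table infos:table name:orders,table description:orders table"
)

chat_factory = ChatFactory()
chat = chat_factory.get_implementation(
    ChatScene.InnerChatDBSummary.value,   # "inner_chat_db_summary"
    temperature=0.7,
    max_new_tokens=512,
    chat_session_id=uuid.uuid1(),
    user_input="how many orders are there?",
    db_select="db_test",
    db_summary=profile_text,
)
result = chat.call()   # expected to contain something like {"table": ["orders"]}
```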

View File

@ -0,0 +1,22 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
return model_out_text
def parse_view_response(self, ai_text) -> str:
return ai_text["table"]
def get_format_instructions(self) -> str:
pass

View File

@ -0,0 +1,58 @@
import builtins
import importlib
import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_knowledge.inner_db_summary.out_parser import NormalChatOutputParser
CFG = Config()
PROMPT_SCENE_DEFINE =""""""
_DEFAULT_TEMPLATE = """
Based on the following known database information, answer which tables are involved in the user input.
Known database information:{db_profile_summary}
Input:{db_input}
You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads
The response format must be JSON, and the key of JSON must be "table".
"""
PROMPT_RESPONSE = """You must respond in JSON format as following format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads
"""
RESPONSE_FORMAT = {
"table": ["orders", "products"]
}
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
prompt = PromptTemplate(
template_scene=ChatScene.InnerChatDBSummary.value,
input_variables=["db_profile_summary", "db_input", "response"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
)
CFG.prompt_templates.update({prompt.template_scene: prompt})
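To preview the text the model actually receives, the raw template can be filled with plain `str.format`; this is only an illustration, since the real pipeline goes through `pilot.prompts.prompt_new.PromptTemplate`:

```python
import json

from pilot.scene.chat_knowledge.inner_db_summary.prompt import (
    _DEFAULT_TEMPLATE,
    PROMPT_RESPONSE,
    RESPONSE_FORMAT,
)

preview = (_DEFAULT_TEMPLATE + PROMPT_RESPONSE).format(
    db_profile_summary="database name:db_test, database type:MYSQL, table infos:...",
    db_input="how many orders are there?",
    response=json.dumps(RESPONSE_FORMAT, indent=4),
)
print(preview)
```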

View File

@ -18,7 +18,7 @@ from pilot.configs.model_config import (
VECTOR_SEARCH_TOP_K,
)
from pilot.scene.chat_normal.prompt import prompt
from pilot.scene.chat_knowledge.url.prompt import prompt
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
CFG = Config()
@ -44,6 +44,7 @@ class ChatUrlKnowledge (BaseChat):
}
self.knowledge_embedding_client = KnowledgeEmbedding(
file_path=url,
file_type="url",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,

View File

@ -4,7 +4,7 @@
from langchain.prompts import PromptTemplate
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
from pilot.conversation import conv_qa_prompt_template
from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates
from pilot.logs import logger
from pilot.model.vicuna_llm import VicunaLLM
from pilot.vector_store.file_loader import KnownLedge2Vector
@ -53,3 +53,17 @@ class KnownLedgeBaseQA:
print("new prompt length:" + str(len(prompt)))
return prompt
@staticmethod
def build_db_summary_prompt(query, db_profile_summary, state):
prompt_template = PromptTemplate(
template=conv_db_summary_templates,
input_variables=["db_input", "db_profile_summary"],
)
# context = [d.page_content for d in docs]
result = prompt_template.format(
db_profile_summary=db_profile_summary, db_input=query
)
state.messages[-2][1] = result
prompt = state.get_prompt()
return prompt

View File

@ -3,16 +3,14 @@
import traceback
import argparse
import datetime
import json
import os
import shutil
import sys
import time
import uuid
import gradio as gr
import requests
from pilot.summary.db_summary_client import DBSummaryClient
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@ -27,13 +25,9 @@ from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
VECTOR_SEARCH_TOP_K,
)
from pilot.conversation import (
SeparatorStyle,
conv_qa_prompt_template,
conv_templates,
conversation_sql_mode,
conversation_types,
chat_mode_title,
@ -41,19 +35,15 @@ from pilot.conversation import (
)
from pilot.common.plugins import scan_plugins
from pilot.prompts.generator import PluginPromptGenerator
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.utils import build_logger, server_error_msg
from pilot.utils import build_logger
from pilot.vector_store.extract_tovec import (
get_vector_storelist,
knownledge_tovec_st,
load_knownledge_from_doc,
)
from pilot.commands.command import execute_ai_response_json
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.language.translation_handler import get_lang_text
@ -75,6 +65,7 @@ vs_list = [get_lang_text("create_knowledge_base")] + get_vector_storelist()
autogpt = False
vector_store_client = None
vector_store_name = {"vs_name": ""}
# db_summary = {"dbsummary": ""}
priority = {"vicuna-13b": "aaa"}
@ -416,6 +407,8 @@ def build_single_model_ui():
show_label=True,
).style(container=False)
db_selector.change(fn=db_selector_changed, inputs=db_selector)
sql_mode = gr.Radio(
[
get_lang_text("sql_generate_mode_direct"),
@ -609,6 +602,10 @@ def save_vs_name(vs_name):
return vs_name
def db_selector_changed(dbname):
DBSummaryClient.db_summary_embedding(dbname)
def knowledge_embedding_store(vs_id, files):
# vs_path = os.path.join(VS_ROOT_PATH, vs_id)
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):

View File

@ -18,12 +18,12 @@ CFG = Config()
class KnowledgeEmbedding:
def __init__(self, file_path, model_name, vector_store_config, local_persist=True):
def __init__(self, file_path, model_name, vector_store_config, local_persist=True, file_type="default"):
"""Initialize with Loader url, model_name, vector_store_config"""
self.file_path = file_path
self.model_name = model_name
self.vector_store_config = vector_store_config
self.file_type = "default"
self.file_type = file_type
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
self.vector_store_config["embeddings"] = self.embeddings
self.local_persist = local_persist
@ -37,7 +37,13 @@ class KnowledgeEmbedding:
self.knowledge_embedding_client.batch_embedding()
def init_knowledge_embedding(self):
if self.file_path.endswith(".pdf"):
if self.file_type == "url":
embedding = URLEmbedding(
file_path=self.file_path,
model_name=self.model_name,
vector_store_config=self.vector_store_config,
)
elif self.file_path.endswith(".pdf"):
embedding = PDFEmbedding(
file_path=self.file_path,
model_name=self.model_name,
@ -56,18 +62,15 @@ class KnowledgeEmbedding:
model_name=self.model_name,
vector_store_config=self.vector_store_config,
)
elif self.file_type == "default":
embedding = MarkdownEmbedding(
file_path=self.file_path,
model_name=self.model_name,
vector_store_config=self.vector_store_config,
)
elif self.file_type == "url":
embedding = URLEmbedding(
file_path=self.file_path,
model_name=self.model_name,
vector_store_config=self.vector_store_config,
)
return embedding
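With the new `file_type` flag, URL ingestion no longer depends on the path suffix. A small sketch mirroring the `chat_knowledge/url` scene earlier in this commit; the URL itself is hypothetical, and `similar_search` is the same call `DBSummaryClient` uses below:

```python
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding

# file_type="url" routes to URLEmbedding regardless of how the path looks.
client = KnowledgeEmbedding(
    file_path="https://example.com/docs",          # hypothetical URL
    file_type="url",
    model_name=LLM_MODEL_CONFIG["text2vec"],
    local_persist=False,
    vector_store_config={"vector_store_name": "example_url"},
)
docs = client.similar_search("pricing", 3)         # query the resulting vector store
```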

View File

@ -11,7 +11,7 @@ from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
class PDFEmbedding(SourceEmbedding):
"""yuque embedding for read yuque document."""
"""pdf embedding for read pdf document."""
def __init__(self, file_path, model_name, vector_store_config):
"""Initialize with pdf path."""

View File

@ -66,6 +66,9 @@ class SourceEmbedding(ABC):
"""vector store similarity_search"""
return self.vector_client.similar_search(doc, topk)
def vector_name_exist(self):
return self.vector_client.vector_name_exists()
def source_embedding(self):
if "read" in registered_methods:
text = self.read()

View File

@ -0,0 +1,30 @@
from typing import List
from langchain.schema import Document
from pilot import SourceEmbedding, register
class StringEmbedding(SourceEmbedding):
"""string embedding for read string document."""
def __init__(self, file_path, model_name, vector_store_config):
"""Initialize with pdf path."""
super().__init__(file_path, model_name, vector_store_config)
self.file_path = file_path
self.model_name = model_name
self.vector_store_config = vector_store_config
@register
def read(self):
"""Load from String path."""
metadata = {"source": "db_summary"}
return [Document(page_content=self.file_path, metadata=metadata)]
@register
def data_process(self, documents: List[Document]):
i = 0
for d in documents:
documents[i].page_content = d.page_content.replace("\n", "")
i += 1
return documents
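`StringEmbedding` treats `file_path` as the raw text to index rather than a path. A sketch of the way `DBSummaryClient` below uses it, assuming the vector store picks up the `embeddings` entry from `vector_store_config`:

```python
from langchain.embeddings import HuggingFaceEmbeddings

from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.source_embedding.string_embedding import StringEmbedding

embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["text2vec"])
embedding = StringEmbedding(
    '{"table_name": "orders", "table_description": "orders table"}',  # raw text, not a path
    LLM_MODEL_CONFIG["text2vec"],
    {"vector_store_name": "db_test_profile", "embeddings": embeddings},
)
if not embedding.vector_name_exist():
    embedding.source_embedding()   # read -> data_process -> index into the vector store
```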

View File

@ -13,6 +13,7 @@ class URLEmbedding(SourceEmbedding):
def __init__(self, file_path, model_name, vector_store_config):
"""Initialize with url path."""
super().__init__(file_path, model_name, vector_store_config)
self.file_path = file_path
self.model_name = model_name
self.vector_store_config = vector_store_config

View File

View File

@ -0,0 +1,31 @@
class DBSummary:
def __init__(self, name):
self.name = name
self.summery = None
self.tables = []
self.metadata = str
def get_summery(self):
return self.summery
class TableSummary:
def __init__(self, name):
self.name = name
self.summery = None
self.fields = []
self.indexes = []
class FieldSummary:
def __init__(self, name):
self.name = name
self.summery = None
self.data_type = None
class IndexSummary:
def __init__(self, name):
self.name = name
self.summery = None
self.bind_fields = []

View File

@ -0,0 +1,176 @@
import json
import uuid
from langchain.embeddings import HuggingFaceEmbeddings, logger
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.scene.base import ChatScene
from pilot.scene.base_chat import BaseChat
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.source_embedding.string_embedding import StringEmbedding
from pilot.summary.mysql_db_summary import MysqlSummary
from pilot.scene.chat_factory import ChatFactory
CFG = Config()
class DBSummaryClient:
"""db summary client, provide db_summary_embedding(put db profile and table profile summary into vector store)
, get_similar_tables method(get user query related tables info)
"""
@staticmethod
def db_summary_embedding(dbname):
"""put db profile and table profile summary into vector store"""
if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None:
db_summary_client = MysqlSummary(dbname)
embeddings = HuggingFaceEmbeddings(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
vector_store_config = {
"vector_store_name": dbname + "_profile",
"embeddings": embeddings,
}
embedding = StringEmbedding(
db_summary_client.get_summery(),
LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config,
)
if not embedding.vector_name_exist():
if CFG.SUMMARY_CONFIG == "FAST":
for vector_table_info in db_summary_client.get_summery():
embedding = StringEmbedding(
vector_table_info,
LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config,
)
embedding.source_embedding()
else:
embedding = StringEmbedding(
db_summary_client.get_summery(),
LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config,
)
embedding.source_embedding()
for (
table_name,
table_summary,
) in db_summary_client.get_table_summary().items():
table_vector_store_config = {
"vector_store_name": table_name + "_ts",
"embeddings": embeddings,
}
embedding = StringEmbedding(
table_summary,
LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
table_vector_store_config,
)
embedding.source_embedding()
logger.info("db summary embedding success")
@staticmethod
def get_similar_tables(dbname, query, topk):
"""get user query related tables info"""
embeddings = HuggingFaceEmbeddings(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
vector_store_config = {
"vector_store_name": dbname + "_profile",
"embeddings": embeddings,
}
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
local_persist=False,
vector_store_config=vector_store_config,
)
if CFG.SUMMARY_CONFIG == "FAST":
table_docs = knowledge_embedding_client.similar_search(query, topk)
related_tables = [json.loads(table_doc.page_content)["table_name"] for table_doc in table_docs]
else:
table_docs = knowledge_embedding_client.similar_search(query, 1)
# prompt = KnownLedgeBaseQA.build_db_summary_prompt(
# query, table_docs[0].page_content
# )
related_tables = _get_llm_response(query, dbname, table_docs[0].page_content)
related_table_summaries = []
for table in related_tables:
vector_store_config = {
"vector_store_name": table + "_ts",
"embeddings": embeddings,
}
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
local_persist=False,
vector_store_config=vector_store_config,
)
table_summery = knowledge_embedding_client.similar_search(query, 1)
related_table_summaries.append(table_summery[0].page_content)
return related_table_summaries
def _get_llm_response(query, db_input, dbsummary):
chat_param = {
"temperature": 0.7,
"max_new_tokens": 512,
"chat_session_id": uuid.uuid1(),
"user_input": query,
"db_input": db_input,
"db_summary": dbsummary,
}
chat_factory = ChatFactory()
chat: BaseChat = chat_factory.get_implementation(ChatScene.InnerChatDBSummary.value, **chat_param)
return chat.call()
# payload = {
# "model": CFG.LLM_MODEL,
# "prompt": prompt,
# "temperature": float(0.7),
# "max_new_tokens": int(512),
# "stop": state.sep
# if state.sep_style == SeparatorStyle.SINGLE
# else state.sep2,
# }
# headers = {"User-Agent": "dbgpt Client"}
# response = requests.post(
# urljoin(CFG.MODEL_SERVER, "generate"),
# headers=headers,
# json=payload,
# timeout=120,
# )
#
# print(related_tables)
# return related_tables
# except NotCommands as e:
# print("llm response error:" + e.message)
# if __name__ == "__main__":
# # summary = DBSummaryClient.get_similar_tables("db_test", "查询在线用户的购物车", 10)
#
# text= """Based on the input "查询在线聊天的用户好友" and the known database information, the tables involved in the user input are "chat_users" and "friends".
# Response:
#
# {
# "table": ["chat_users"]
# }"""
# text = text.rstrip().replace("\n","")
# start = text.find("{")
# end = text.find("}") + 1
#
# # Extract the JSON fragment from the string
# json_str = text[start:end]
#
# # Convert the JSON string into a Python dict
# data = json.loads(json_str)
# # pattern = r'\{\s*"table"\s*:\s*\[[^\]]*\]\s*\}'
# # match = re.search(pattern, text)
# # if match:
# # json_string = match.group(0)
# # # 将JSON字符串转换为Python对象
# # json_obj = json.loads(json_string)
# # print(summary)
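Taken together, the two public entry points are wired up elsewhere in this commit: `webserver.db_selector_changed` builds the profile when a database is selected, and the `chat_db` scenes call `get_similar_tables` while assembling their prompts. A minimal end-to-end sketch (database name and query are hypothetical):

```python
from pilot.summary.db_summary_client import DBSummaryClient

DBSummaryClient.db_summary_embedding("db_test")     # index db + per-table summaries

related = DBSummaryClient.get_similar_tables(
    dbname="db_test",
    query="how many orders are there?",
    topk=5,
)
for table_summary in related:
    print(table_summary)   # e.g. "database name:db_test, table name:orders, have columns info: ..."
```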

View File

@ -0,0 +1,134 @@
import json
from pilot.configs.config import Config
from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, IndexSummary
CFG = Config()
class MysqlSummary(DBSummary):
"""Get mysql summary template."""
def __init__(self, name):
self.name = name
self.type = "MYSQL"
self.summery = (
"""database name:{name}, database type:{type}, table infos:{table_info}"""
)
self.tables = {}
self.tables_info = []
self.vector_tables_info = []
# self.tables_summary = {}
self.db = CFG.local_db
self.db.get_session(name)
self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
users=self.db.get_users(),
grant=self.db.get_grants(),
charset=self.db.get_charset(),
collation=self.db.get_collation(),
)
tables = self.db.get_table_names()
self.table_comments = self.db.get_table_comments(name)
for table_comment in self.table_comments:
self.tables_info.append(
"table name:{table_name},table description:{table_comment}".format(
table_name=table_comment[0], table_comment=table_comment[1]
)
)
vector_table = json.dumps(
{"table_name": table_comment[0], "table_description": table_comment[1]}
)
self.vector_tables_info.append(
vector_table.encode("utf-8").decode("unicode_escape")
)
for table_name in tables:
table_summary = MysqlTableSummary(self.db, name, table_name)
self.tables[table_name] = table_summary.get_summery()
# self.tables_info.append(table_summary.get_summery())
def get_summery(self):
if CFG.SUMMARY_CONFIG == "VECTOR":
return self.vector_tables_info
else:
return self.summery.format(
name=self.name, type=self.type, table_info=";".join(self.tables_info)
)
def get_table_summary(self):
return self.tables
def get_table_comments(self):
return self.table_comments
class MysqlTableSummary(TableSummary):
"""Get mysql table summary template."""
def __init__(self, instance, dbname, name):
self.name = name
self.dbname = dbname
self.summery = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}"""
self.fields = []
self.fields_info = []
self.indexes = []
self.indexes_info = []
self.db = instance
fields = self.db.get_fields(name)
indexes = self.db.get_indexes(name)
for field in fields:
field_summary = MysqlFieldsSummary(field)
self.fields.append(field_summary)
self.fields_info.append(field_summary.get_summery())
for index in indexes:
index_summary = MysqlIndexSummary(index)
self.indexes.append(index_summary)
self.indexes_info.append(index_summary.get_summery())
def get_summery(self):
return self.summery.format(
name=self.name,
dbname=self.dbname,
fields=";".join(self.fields_info),
indexes=";".join(self.indexes_info),
)
class MysqlFieldsSummary(FieldSummary):
"""Get mysql field summary template."""
def __init__(self, field):
self.name = field[0]
self.summery = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """
self.data_type = field[1]
self.default_value = field[2]
self.is_nullable = field[3]
self.comment = field[4]
def get_summery(self):
return self.summery.format(
name=self.name,
data_type=self.data_type,
is_nullable=self.is_nullable,
default_value=self.default_value,
comment=self.comment,
)
class MysqlIndexSummary(IndexSummary):
"""Get mysql index summary template."""
def __init__(self, index):
self.name = index[0]
self.summery = """index name:{name}, index bind columns:{bind_fields}"""
self.bind_fields = index[1]
def get_summery(self):
return self.summery.format(name=self.name, bind_fields=self.bind_fields)
if __name__ == "__main__":
summary = MysqlSummary("db_test")
print(summary.get_summery())

View File

@ -24,6 +24,11 @@ class ChromaStore(VectorStoreBase):
logger.info("ChromaStore similar search")
return self.vector_store_client.similarity_search(text, topk)
def vector_name_exists(self):
return (
os.path.exists(self.persist_dir) and len(os.listdir(self.persist_dir)) > 0
)
def load_document(self, documents):
logger.info("ChromaStore load document")
texts = [doc.page_content for doc in documents]

View File

@ -5,15 +5,22 @@ connector = {"Chroma": ChromaStore, "Milvus": None}
class VectorStoreConnector:
"""vector store connector, can connect different vector db provided load document api and similar search api"""
"""vector store connector, can connect different vector db provided load document api and similar search api."""
def __init__(self, vector_store_type, ctx: {}) -> None:
"""initialize vector store connector."""
self.ctx = ctx
self.connector_class = connector[vector_store_type]
self.client = self.connector_class(ctx)
def load_document(self, docs):
"""load document in vector database."""
self.client.load_document(docs)
def similar_search(self, docs, topk):
"""similar search in vector database."""
return self.client.similar_search(docs, topk)
def vector_name_exists(self):
"""is vector store name exist."""
return self.client.vector_name_exists()
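A usage sketch of the connector, assuming the import path and that `ChromaStore` only needs the `vector_store_name` and `embeddings` keys used elsewhere in this commit (its full `ctx` requirements are not visible in this diff):

```python
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document

from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.vector_store.connector import VectorStoreConnector  # assumed module path

ctx = {
    "vector_store_name": "db_test_profile",
    "embeddings": HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["text2vec"]),
}
connector = VectorStoreConnector("Chroma", ctx)  # "Milvus" is declared but not wired up yet

if not connector.vector_name_exists():
    connector.load_document([Document(page_content="table name:orders,table description:orders table")])

print(connector.similar_search("orders", 1))
```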

View File

@ -319,5 +319,9 @@ class MilvusStore(VectorStoreBase):
return data[0], ret
def vector_name_exists(self):
"""is vector store name exist."""
return utility.has_collection(self.collection_name)
def close(self):
connections.disconnect()

View File

@ -11,5 +11,10 @@ class VectorStoreBase(ABC):
@abstractmethod
def similar_search(self, text, topk) -> None:
"""Initialize schema in vector database."""
"""similar search in vector database."""
pass
@abstractmethod
def vector_name_exists(self):
"""Return whether the vector store name already exists."""
pass