mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-01 16:18:27 +00:00

fix: merge (#153)

commit 4d3079055c
.env.template

@@ -28,8 +28,12 @@ MAX_POSITION_EMBEDDINGS=4096
 # FAST_LLM_MODEL=chatglm-6b


-### EMBEDDINGS
-## EMBEDDING_MODEL - Model to use for creating embeddings
+#*******************************************************************#
+#** EMBEDDING SETTINGS **#
+#*******************************************************************#
+EMBEDDING_MODEL=text2vec
+KNOWLEDGE_CHUNK_SIZE=500
+KNOWLEDGE_SEARCH_TOP_SIZE=5
 ## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
 ## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
 # EMBEDDING_MODEL=all-MiniLM-L6-v2
pilot/configs/config.py

@@ -148,6 +148,8 @@ class Config(metaclass=Singleton):

         ### EMBEDDING Configuration
         self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
+        self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500))
+        self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 10))
         ### SUMMARY_CONFIG Configuration
         self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "VECTOR")

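For reference, a minimal sketch of how the two new knobs resolve at runtime, assuming the `.env` values have already been loaded into the process environment (e.g. via python-dotenv). Note the differing fallbacks: the template ships `KNOWLEDGE_SEARCH_TOP_SIZE=5`, while the code falls back to 10 when the variable is unset.

```python
# Sketch only: mirrors the two Config lines above outside the Singleton class.
import os

chunk_size = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500))        # characters per chunk
top_size = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 10))      # hits per similarity query
print(f"chunk_size={chunk_size}, search_top_size={top_size}")
```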
pilot/configs/model_config.py

@@ -35,7 +35,6 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
     "proxyllm": "proxyllm",
 }
(new file)

@@ -0,0 +1 @@
+LlamaIndex是一个数据框架,旨在帮助您构建LLM应用程序。它包括一个向量存储索引和一个简单的目录阅读器,可以帮助您处理和操作数据。此外,LlamaIndex还提供了一个GPT Index,可以用于数据增强和生成更好的LM模型。
pilot/model/adapter.py

@@ -82,7 +82,7 @@ class ChatGLMAdapater(BaseLLMAdaper):
         )
         return model, tokenizer


 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
@@ -97,7 +97,6 @@ class GuanacoAdapter(BaseLLMAdaper):
         return model, tokenizer



 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
pilot/scene/chat_db/auto_execute/chat.py

@@ -47,12 +47,13 @@ class ChatWithDbAutoExecute(BaseChat):
             from pilot.summary.db_summary_client import DBSummaryClient
         except ImportError:
             raise ValueError("Could not import DBSummaryClient. ")
+        client = DBSummaryClient()
         input_values = {
             "input": self.current_user_input,
             "top_k": str(self.top_k),
             "dialect": self.database.dialect,
             "table_info": self.database.table_simple_info(self.db_connect)
-            # "table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
+            # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
         }
         return input_values
pilot/scene/chat_db/professional_qa/chat.py

@@ -45,7 +45,8 @@ class ChatWithDbQA(BaseChat):
         except ImportError:
             raise ValueError("Could not import DBSummaryClient. ")
         if self.db_name:
-            table_info = DBSummaryClient.get_similar_tables(
+            client = DBSummaryClient()
+            table_info = client.get_similar_tables(
                 dbname=self.db_name, query=self.current_user_input, topk=self.top_k
             )
             # table_info = self.database.table_simple_info(self.db_connect)
pilot/scene/chat_knowledge/custom/chat.py

@@ -14,7 +14,6 @@ from pilot.configs.model_config import (
     KNOWLEDGE_UPLOAD_ROOT_PATH,
     LLM_MODEL_CONFIG,
     LOGDIR,
-    VECTOR_SEARCH_TOP_K,
 )

 from pilot.scene.chat_knowledge.custom.prompt import prompt
@@ -46,15 +45,13 @@ class ChatNewKnowledge(BaseChat):
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
             model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
             vector_store_config=vector_store_config,
         )

     def generate_input_values(self):
         docs = self.knowledge_embedding_client.similar_search(
-            self.current_user_input, VECTOR_SEARCH_TOP_K
+            self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
         )
         context = [d.page_content for d in docs]
         context = context[:2000]
pilot/scene/chat_knowledge/custom/prompt.py

@@ -14,13 +14,23 @@ CFG = Config()
 PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""


-_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
 已知内容:
 {context}
 问题:
 {question}
 """
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
+known information:
+{context}
+question:
+{question}
+"""
+
+_DEFAULT_TEMPLATE = (
+    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)


 PROMPT_SEP = SeparatorStyle.SINGLE.value
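The same ZH/EN template pair and `CFG.LANGUAGE` switch recurs in each prompt module touched below. A small self-contained sketch of the pattern, with illustrative stand-in names rather than the project's:

```python
# Illustrative sketch of the language-switch pattern used by these prompt
# modules; LANGUAGE stands in for CFG.LANGUAGE read from configuration.
TEMPLATE_ZH = "已知内容:\n{context}\n问题:\n{question}"
TEMPLATE_EN = "known information:\n{context}\nquestion:\n{question}"

LANGUAGE = "en"  # would come from config in the real code
TEMPLATE = TEMPLATE_EN if LANGUAGE == "en" else TEMPLATE_ZH

# The placeholders are filled per request, e.g.:
print(TEMPLATE.format(context="<retrieved chunks>", question="<user question>"))
```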
pilot/scene/chat_knowledge/default/chat.py

@@ -14,7 +14,6 @@ from pilot.configs.model_config import (
     KNOWLEDGE_UPLOAD_ROOT_PATH,
     LLM_MODEL_CONFIG,
     LOGDIR,
-    VECTOR_SEARCH_TOP_K,
 )

 from pilot.scene.chat_knowledge.default.prompt import prompt
@@ -42,15 +41,13 @@ class ChatDefaultKnowledge(BaseChat):
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
             model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
             vector_store_config=vector_store_config,
         )

     def generate_input_values(self):
         docs = self.knowledge_embedding_client.similar_search(
-            self.current_user_input, VECTOR_SEARCH_TOP_K
+            self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
         )
         context = [d.page_content for d in docs]
         context = context[:2000]
pilot/scene/chat_knowledge/default/prompt.py

@@ -15,13 +15,23 @@ PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelli
 The assistant gives helpful, detailed, professional and polite answers to the user's questions. """


-_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
 已知内容:
 {context}
 问题:
 {question}
 """
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
+known information:
+{context}
+question:
+{question}
+"""
+
+_DEFAULT_TEMPLATE = (
+    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)


 PROMPT_SEP = SeparatorStyle.SINGLE.value
pilot/scene/chat_knowledge/url/chat.py

@@ -14,7 +14,6 @@ from pilot.configs.model_config import (
     KNOWLEDGE_UPLOAD_ROOT_PATH,
     LLM_MODEL_CONFIG,
     LOGDIR,
-    VECTOR_SEARCH_TOP_K,
 )

 from pilot.scene.chat_knowledge.url.prompt import prompt
@@ -40,15 +39,13 @@ class ChatUrlKnowledge(BaseChat):
         self.url = url
         vector_store_config = {
             "vector_store_name": url,
-            "text_field": "content",
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
-            file_path=url,
-            file_type="url",
             model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
             vector_store_config=vector_store_config,
+            file_type="url",
+            file_path=url,
         )

         # url soruce in vector
@@ -58,7 +55,7 @@ class ChatUrlKnowledge(BaseChat):

     def generate_input_values(self):
         docs = self.knowledge_embedding_client.similar_search(
-            self.current_user_input, VECTOR_SEARCH_TOP_K
+            self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
         )
         context = [d.page_content for d in docs]
         context = context[:2000]
pilot/scene/chat_knowledge/url/prompt.py

@@ -14,20 +14,23 @@ CFG = Config()
 PROMPT_SCENE_DEFINE = """A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge.
 The assistant gives helpful, detailed, professional and polite answers to the user's questions. """


-# _DEFAULT_TEMPLATE = """ Based on the known information, provide professional and concise answers to the user's questions. If the answer cannot be obtained from the provided content, please say: 'The information provided in the knowledge base is not sufficient to answer this question.' Fabrication is prohibited.。
-# known information:
-# {context}
-# question:
-# {question}
-# """
-_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
 已知内容:
 {context}
 问题:
 {question}
 """
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
+known information:
+{context}
+question:
+{question}
+"""
+
+_DEFAULT_TEMPLATE = (
+    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)


 PROMPT_SEP = SeparatorStyle.SINGLE.value
pilot/server/chat_adapter.py

@@ -59,6 +59,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):

         return chatglm_generate_stream


 class CodeT5ChatAdapter(BaseChatAdpter):

     """Model chat adapter for CodeT5"""
pilot/server/vectordb_qa.py

@@ -3,12 +3,14 @@

 from langchain.prompts import PromptTemplate

-from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
+from pilot.configs.config import Config
 from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates
 from pilot.logs import logger
 from pilot.model.llm_out.vicuna_llm import VicunaLLM
 from pilot.vector_store.file_loader import KnownLedge2Vector

+CFG = Config()
+

 class KnownLedgeBaseQA:
     def __init__(self) -> None:
@@ -22,7 +24,7 @@ class KnownLedgeBaseQA:
         )

         retriever = self.vector_store.as_retriever(
-            search_kwargs={"k": VECTOR_SEARCH_TOP_K}
+            search_kwargs={"k": CFG.KNOWLEDGE_SEARCH_TOP_SIZE}
         )
         docs = retriever.get_relevant_documents(query=query)
pilot/server/webserver.py

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
+import threading
 import traceback
 import argparse
 import datetime
@@ -415,7 +416,7 @@ def build_single_model_ui():
             show_label=True,
         ).style(container=False)

-        db_selector.change(fn=db_selector_changed, inputs=db_selector)
+        # db_selector.change(fn=db_selector_changed, inputs=db_selector)

         sql_mode = gr.Radio(
             [
@@ -619,10 +620,6 @@ def save_vs_name(vs_name):
     return vs_name


-def db_selector_changed(dbname):
-    DBSummaryClient.db_summary_embedding(dbname)
-
-
 def knowledge_embedding_store(vs_id, files):
     # vs_path = os.path.join(VS_ROOT_PATH, vs_id)
     if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
@@ -635,7 +632,6 @@ def knowledge_embedding_store(vs_id, files):
         knowledge_embedding_client = KnowledgeEmbedding(
             file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
             model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
             vector_store_config={
                 "vector_store_name": vector_store_name["vs_name"],
                 "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -647,6 +643,12 @@ def knowledge_embedding_store(vs_id, files):
     return vs_id


+def async_db_summery():
+    client = DBSummaryClient()
+    thread = threading.Thread(target=client.init_db_summary)
+    thread.start()
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -663,7 +665,7 @@ if __name__ == "__main__":
     cfg = Config()

     dbs = cfg.local_db.get_database_list()

+    async_db_summery()
     cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

     # 加载插件可执行命令
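`async_db_summery` pushes the potentially slow database summarization off the startup path. A minimal, self-contained sketch of the same fire-and-forget pattern; the `warm_up` body and the `daemon=True` flag are illustrative additions, not the project's code:

```python
import threading
import time

def warm_up():
    # Stand-in for DBSummaryClient().init_db_summary(): any slow,
    # one-off initialization that should not block server startup.
    time.sleep(2)
    print("summary cache ready")

thread = threading.Thread(target=warm_up, daemon=True)
thread.start()
print("server continues starting up while warm_up runs")
thread.join()  # only for this demo; the server never joins
```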
pilot/source_embedding/EncodeTextLoader.py (new file, 26 lines)

@@ -0,0 +1,26 @@
+from typing import List, Optional
+import chardet
+
+from langchain.docstore.document import Document
+from langchain.document_loaders.base import BaseLoader
+
+
+class EncodeTextLoader(BaseLoader):
+    """Load text files."""
+
+    def __init__(self, file_path: str, encoding: Optional[str] = None):
+        """Initialize with file path."""
+        self.file_path = file_path
+        self.encoding = encoding
+
+    def load(self) -> List[Document]:
+        """Load from file path."""
+        with open(self.file_path, "rb") as f:
+            raw_text = f.read()
+            result = chardet.detect(raw_text)
+            if result["encoding"] is None:
+                text = raw_text.decode("utf-8")
+            else:
+                text = raw_text.decode(result["encoding"])
+            metadata = {"source": self.file_path}
+            return [Document(page_content=text, metadata=metadata)]
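The point of this loader is to sniff the file's encoding with chardet instead of assuming UTF-8, which matters for the Chinese knowledge documents this project ingests. A quick usage sketch; the GBK sample file is fabricated for illustration, and the text is repeated because chardet needs a reasonable amount of input to detect confidently:

```python
# Write a GBK-encoded file, then let EncodeTextLoader detect the codec.
# Assumes chardet and langchain are installed; the sample path is made up.
from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader

with open("/tmp/sample_gbk.txt", "wb") as f:
    f.write(("基于已知信息回答问题。" * 20).encode("gbk"))

docs = EncodeTextLoader("/tmp/sample_gbk.txt").load()
print(docs[0].page_content[:20])  # decoded text, no UnicodeDecodeError
print(docs[0].metadata)           # {'source': '/tmp/sample_gbk.txt'}
```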
pilot/source_embedding/csv_embedding.py

@@ -12,14 +12,12 @@ class CSVEmbedding(SourceEmbedding):
     def __init__(
         self,
         file_path,
         model_name,
         vector_store_config,
         embedding_args: Optional[Dict] = None,
     ):
         """Initialize with csv path."""
-        super().__init__(file_path, model_name, vector_store_config)
+        super().__init__(file_path, vector_store_config)
         self.file_path = file_path
         self.model_name = model_name
         self.vector_store_config = vector_store_config
         self.embedding_args = embedding_args

pilot/source_embedding/knowledge_embedding.py

@@ -1,30 +1,34 @@
 import os
+from typing import Optional

 import markdown
 from bs4 import BeautifulSoup
 from langchain.document_loaders import PyPDFLoader, TextLoader
 from langchain.embeddings import HuggingFaceEmbeddings

 from pilot.configs.config import Config
 from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
 from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
 from pilot.source_embedding.csv_embedding import CSVEmbedding
 from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
 from pilot.source_embedding.pdf_embedding import PDFEmbedding
 from pilot.source_embedding.url_embedding import URLEmbedding
+from pilot.source_embedding.word_embedding import WordEmbedding
 from pilot.vector_store.connector import VectorStoreConnector

 CFG = Config()

+KnowledgeEmbeddingType = {
+    ".txt": (MarkdownEmbedding, {}),
+    ".md": (MarkdownEmbedding, {}),
+    ".pdf": (PDFEmbedding, {}),
+    ".doc": (WordEmbedding, {}),
+    ".docx": (WordEmbedding, {}),
+    ".csv": (CSVEmbedding, {}),
+}
+

 class KnowledgeEmbedding:
     def __init__(
         self,
-        file_path,
         model_name,
         vector_store_config,
         local_persist=True,
-        file_type="default",
+        file_type: Optional[str] = "default",
+        file_path: Optional[str] = None,
     ):
         """Initialize with Loader url, model_name, vector_store_config"""
         self.file_path = file_path
@@ -33,11 +37,9 @@ class KnowledgeEmbedding:
         self.file_type = file_type
         self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
         self.vector_store_config["embeddings"] = self.embeddings
-        self.local_persist = local_persist
-        if not self.local_persist:
-            self.knowledge_embedding_client = self.init_knowledge_embedding()
+        self.knowledge_embedding_client = self.init_knowledge_embedding()

     def knowledge_embedding(self):
         self.knowledge_embedding_client = self.init_knowledge_embedding()
         self.knowledge_embedding_client.source_embedding()

     def knowledge_embedding_batch(self):
@@ -50,95 +52,27 @@ class KnowledgeEmbedding:
                 model_name=self.model_name,
                 vector_store_config=self.vector_store_config,
             )
-        elif self.file_path.endswith(".pdf"):
-            embedding = PDFEmbedding(
-                file_path=self.file_path,
-                model_name=self.model_name,
+            return embedding
+        extension = "." + self.file_path.rsplit(".", 1)[-1]
+        if extension in KnowledgeEmbeddingType:
+            knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
+            embedding = knowledge_class(
+                self.file_path,
                 vector_store_config=self.vector_store_config,
+                **knowledge_args,
             )
-        elif self.file_path.endswith(".md"):
-            embedding = MarkdownEmbedding(
-                file_path=self.file_path,
-                model_name=self.model_name,
-                vector_store_config=self.vector_store_config,
-            )
-
-        elif self.file_path.endswith(".csv"):
-            embedding = CSVEmbedding(
-                file_path=self.file_path,
-                model_name=self.model_name,
-                vector_store_config=self.vector_store_config,
-            )
-
-        elif self.file_type == "default":
-            embedding = MarkdownEmbedding(
-                file_path=self.file_path,
-                model_name=self.model_name,
-                vector_store_config=self.vector_store_config,
-            )
-
-        return embedding
+            return embedding
+        raise ValueError(f"Unsupported knowledge file type '{extension}'")

     def similar_search(self, text, topk):
-        return self.knowledge_embedding_client.similar_search(text, topk)
-
-    def vector_exist(self):
-        return self.knowledge_embedding_client.vector_name_exist()
-
-    def knowledge_persist_initialization(self, append_mode):
-        documents = self._load_knownlege(self.file_path)
-        self.vector_client = VectorStoreConnector(
+        vector_client = VectorStoreConnector(
             CFG.VECTOR_STORE_TYPE, self.vector_store_config
         )
-        self.vector_client.load_document(documents)
-        return self.vector_client
+        return vector_client.similar_search(text, topk)

-    def _load_knownlege(self, path):
-        docments = []
-        for root, _, files in os.walk(path, topdown=False):
-            for file in files:
-                filename = os.path.join(root, file)
-                docs = self._load_file(filename)
-                new_docs = []
-                for doc in docs:
-                    doc.metadata = {
-                        "source": doc.metadata["source"].replace(DATASETS_DIR, "")
-                    }
-                    print("doc is embedding...", doc.metadata)
-                    new_docs.append(doc)
-                docments += new_docs
-        return docments
-
-    def _load_file(self, filename):
-        if filename.lower().endswith(".md"):
-            loader = TextLoader(filename)
-            text_splitter = CHNDocumentSplitter(
-                pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
-            )
-            docs = loader.load_and_split(text_splitter)
-            i = 0
-            for d in docs:
-                content = markdown.markdown(d.page_content)
-                soup = BeautifulSoup(content, "html.parser")
-                for tag in soup(["!doctype", "meta", "i.fa"]):
-                    tag.extract()
-                docs[i].page_content = soup.get_text()
-                docs[i].page_content = docs[i].page_content.replace("\n", " ")
-                i += 1
-        elif filename.lower().endswith(".pdf"):
-            loader = PyPDFLoader(filename)
-            textsplitter = CHNDocumentSplitter(
-                pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
-            )
-            docs = loader.load_and_split(textsplitter)
-            i = 0
-            for d in docs:
-                docs[i].page_content = d.page_content.replace("\n", " ").replace(
-                    "<EFBFBD>", ""
-                )
-                i += 1
-        else:
-            loader = TextLoader(filename)
-            text_splitor = CHNDocumentSplitter(sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
-            docs = loader.load_and_split(text_splitor)
-        return docs

+    def vector_exist(self):
+        vector_client = VectorStoreConnector(
+            CFG.VECTOR_STORE_TYPE, self.vector_store_config
+        )
+        return vector_client.vector_name_exists()
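The refactor above replaces a per-format `if`/`elif` chain with a lookup table keyed by file extension, so adding a format means adding one dict entry. A standalone sketch of the same registry pattern; the loader classes here are dummies, not the project's:

```python
# Standalone sketch of the extension-dispatch pattern introduced above;
# DummyMarkdown/DummyPdf stand in for the real embedding classes.
class DummyMarkdown:
    def __init__(self, file_path, **kwargs):
        self.file_path = file_path

class DummyPdf(DummyMarkdown):
    pass

REGISTRY = {
    ".md": (DummyMarkdown, {}),
    ".txt": (DummyMarkdown, {}),
    ".pdf": (DummyPdf, {}),
}

def build_loader(file_path):
    extension = "." + file_path.rsplit(".", 1)[-1]
    if extension in REGISTRY:
        cls, extra_args = REGISTRY[extension]
        return cls(file_path, **extra_args)
    raise ValueError(f"Unsupported knowledge file type '{extension}'")

print(type(build_loader("notes.pdf")).__name__)  # DummyPdf
```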
pilot/source_embedding/markdown_embedding.py

@@ -8,27 +8,30 @@ from bs4 import BeautifulSoup
 from langchain.document_loaders import TextLoader
 from langchain.schema import Document

-from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
+from pilot.configs.config import Config
 from pilot.source_embedding import SourceEmbedding, register
+from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader
 from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

+CFG = Config()
+

 class MarkdownEmbedding(SourceEmbedding):
     """markdown embedding for read markdown document."""

-    def __init__(self, file_path, model_name, vector_store_config):
+    def __init__(self, file_path, vector_store_config):
         """Initialize with markdown path."""
-        super().__init__(file_path, model_name, vector_store_config)
+        super().__init__(file_path, vector_store_config)
         self.file_path = file_path
-        self.model_name = model_name
         self.vector_store_config = vector_store_config
         # self.encoding = encoding

     @register
     def read(self):
         """Load from markdown path."""
-        loader = TextLoader(self.file_path)
+        loader = EncodeTextLoader(self.file_path)
         text_splitter = CHNDocumentSplitter(
-            pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
+            pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
         )
         return loader.load_and_split(text_splitter)
pilot/source_embedding/pdf_embedding.py

@@ -5,20 +5,22 @@ from typing import List
 from langchain.document_loaders import PyPDFLoader
 from langchain.schema import Document

-from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
+from pilot.configs.config import Config
 from pilot.source_embedding import SourceEmbedding, register
 from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

+CFG = Config()
+

 class PDFEmbedding(SourceEmbedding):
     """pdf embedding for read pdf document."""

-    def __init__(self, file_path, model_name, vector_store_config):
+    def __init__(self, file_path, vector_store_config, encoding):
         """Initialize with pdf path."""
-        super().__init__(file_path, model_name, vector_store_config)
+        super().__init__(file_path, vector_store_config)
         self.file_path = file_path
-        self.model_name = model_name
         self.vector_store_config = vector_store_config
+        self.encoding = encoding

     @register
     def read(self):
@@ -26,7 +28,7 @@ class PDFEmbedding(SourceEmbedding):
         # loader = UnstructuredPaddlePDFLoader(self.file_path)
         loader = PyPDFLoader(self.file_path)
         textsplitter = CHNDocumentSplitter(
-            pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
+            pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
         )
         return loader.load_and_split(textsplitter)
pilot/source_embedding/source_embedding.py

@@ -23,13 +23,11 @@ class SourceEmbedding(ABC):
     def __init__(
         self,
         file_path,
-        model_name,
         vector_store_config,
         embedding_args: Optional[Dict] = None,
     ):
         """Initialize with Loader url, model_name, vector_store_config"""
         self.file_path = file_path
-        self.model_name = model_name
         self.vector_store_config = vector_store_config
         self.embedding_args = embedding_args
         self.embeddings = vector_store_config["embeddings"]
pilot/source_embedding/string_embedding.py

@@ -8,11 +8,10 @@ from pilot import SourceEmbedding, register
 class StringEmbedding(SourceEmbedding):
     """string embedding for read string document."""

-    def __init__(self, file_path, model_name, vector_store_config):
+    def __init__(self, file_path, vector_store_config):
         """Initialize with pdf path."""
-        super().__init__(file_path, model_name, vector_store_config)
+        super().__init__(file_path, vector_store_config)
         self.file_path = file_path
-        self.model_name = model_name
         self.vector_store_config = vector_store_config

     @register
pilot/source_embedding/url_embedding.py

@@ -16,11 +16,10 @@ CFG = Config()
 class URLEmbedding(SourceEmbedding):
     """url embedding for read url document."""

-    def __init__(self, file_path, model_name, vector_store_config):
+    def __init__(self, file_path, vector_store_config):
         """Initialize with url path."""
-        super().__init__(file_path, model_name, vector_store_config)
+        super().__init__(file_path, vector_store_config)
         self.file_path = file_path
-        self.model_name = model_name
         self.vector_store_config = vector_store_config

     @register
@@ -29,7 +28,7 @@ class URLEmbedding(SourceEmbedding):
         loader = WebBaseLoader(web_path=self.file_path)
         if CFG.LANGUAGE == "en":
             text_splitter = CharacterTextSplitter(
-                chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE,
+                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
                 chunk_overlap=20,
                 length_function=len,
             )
pilot/source_embedding/word_embedding.py (new file, 39 lines)

@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from typing import List
+
+from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader
+from langchain.schema import Document
+
+from pilot.configs.config import Config
+from pilot.source_embedding import SourceEmbedding, register
+from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
+
+CFG = Config()
+
+
+class WordEmbedding(SourceEmbedding):
+    """word embedding for read word document."""
+
+    def __init__(self, file_path, vector_store_config):
+        """Initialize with word path."""
+        super().__init__(file_path, vector_store_config)
+        self.file_path = file_path
+        self.vector_store_config = vector_store_config
+
+    @register
+    def read(self):
+        """Load from word path."""
+        loader = UnstructuredWordDocumentLoader(self.file_path)
+        textsplitter = CHNDocumentSplitter(
+            pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
+        )
+        return loader.load_and_split(textsplitter)
+
+    @register
+    def data_process(self, documents: List[Document]):
+        i = 0
+        for d in documents:
+            documents[i].page_content = d.page_content.replace("\n", "")
+            i += 1
+        return documents
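`WordEmbedding` plugs into the `SourceEmbedding` pipeline through the `@register` hooks: `read` loads and splits the document, `data_process` post-processes the chunks. A rough sketch of how the two stages compose when exercised directly; the `.docx` path and the `vector_store_config` contents are made up, and in the real flow `SourceEmbedding.source_embedding()` drives these hooks instead:

```python
# Illustrative only: calls read() and data_process() by hand rather than
# through SourceEmbedding.source_embedding().
from pilot.source_embedding.word_embedding import WordEmbedding

embedding = WordEmbedding(
    file_path="/data/knowledge/handbook.docx",   # hypothetical document
    vector_store_config={
        "vector_store_name": "handbook",
        "embeddings": None,  # a HuggingFaceEmbeddings instance in real use
    },
)
chunks = embedding.read()                # UnstructuredWordDocumentLoader + CHN splitter
chunks = embedding.data_process(chunks)  # strip newlines inside each chunk
print(len(chunks), "chunks ready for the vector store")
```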
pilot/summary/db_summary_client.py

@@ -21,8 +21,10 @@ class DBSummaryClient:
     , get_similar_tables method(get user query related tables info)
     """

-    @staticmethod
-    def db_summary_embedding(dbname):
+    def __init__(self):
+        pass
+
+    def db_summary_embedding(self, dbname):
         """put db profile and table profile summary into vector store"""
         if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None:
             db_summary_client = MysqlSummary(dbname)
@@ -34,24 +36,21 @@ class DBSummaryClient:
             "embeddings": embeddings,
         }
         embedding = StringEmbedding(
-            db_summary_client.get_summery(),
-            LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
-            vector_store_config,
+            file_path=db_summary_client.get_summery(),
+            vector_store_config=vector_store_config,
         )
         if not embedding.vector_name_exist():
             if CFG.SUMMARY_CONFIG == "FAST":
                 for vector_table_info in db_summary_client.get_summery():
                     embedding = StringEmbedding(
                         vector_table_info,
-                        LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                         vector_store_config,
                     )
                     embedding.source_embedding()
             else:
                 embedding = StringEmbedding(
-                    db_summary_client.get_summery(),
-                    LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
-                    vector_store_config,
+                    file_path=db_summary_client.get_summery(),
+                    vector_store_config=vector_store_config,
                 )
                 embedding.source_embedding()
         for (
@@ -59,32 +58,24 @@ class DBSummaryClient:
             table_summary,
         ) in db_summary_client.get_table_summary().items():
             table_vector_store_config = {
-                "vector_store_name": table_name + "_ts",
+                "vector_store_name": dbname + "_" + table_name + "_ts",
                 "embeddings": embeddings,
             }
             embedding = StringEmbedding(
                 table_summary,
-                LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                 table_vector_store_config,
             )
             embedding.source_embedding()

         logger.info("db summary embedding success")

-    @staticmethod
-    def get_similar_tables(dbname, query, topk):
+    def get_similar_tables(self, dbname, query, topk):
         """get user query related tables info"""
         embeddings = HuggingFaceEmbeddings(
             model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
         )
         vector_store_config = {
             "vector_store_name": dbname + "_profile",
             "embeddings": embeddings,
         }
         knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
             model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
-            local_persist=False,
             vector_store_config=vector_store_config,
         )
         if CFG.SUMMARY_CONFIG == "FAST":
@@ -104,19 +95,23 @@ class DBSummaryClient:
         related_table_summaries = []
         for table in related_tables:
             vector_store_config = {
-                "vector_store_name": table + "_ts",
-                "embeddings": embeddings,
+                "vector_store_name": dbname + "_" + table + "_ts",
             }
             knowledge_embedding_client = KnowledgeEmbedding(
                 file_path="",
                 model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                 local_persist=False,
                 vector_store_config=vector_store_config,
             )
             table_summery = knowledge_embedding_client.similar_search(query, 1)
             related_table_summaries.append(table_summery[0].page_content)
         return related_table_summaries

+    def init_db_summary(self):
+        db = CFG.local_db
+        dbs = db.get_database_list()
+        for dbname in dbs:
+            self.db_summary_embedding(dbname)
+

 def _get_llm_response(query, db_input, dbsummary):
     chat_param = {
@@ -132,30 +127,3 @@ def _get_llm_response(query, db_input, dbsummary):
     )
     res = chat.nostream_call()
     return json.loads(res)["table"]
-
-
-# if __name__ == "__main__":
-#     # summary = DBSummaryClient.get_similar_tables("db_test", "查询在线用户的购物车", 10)
-#
-#     text= """Based on the input "查询在线聊天的用户好友" and the known database information, the tables involved in the user input are "chat_users" and "friends".
-#     Response:
-#
-#     {
-#         "table": ["chat_users"]
-#     }"""
-#     text = text.rstrip().replace("\n","")
-#     start = text.find("{")
-#     end = text.find("}") + 1
-#
-#     # 从字符串中截取出JSON数据
-#     json_str = text[start:end]
-#
-#     # 将JSON数据转换为Python中的字典类型
-#     data = json.loads(json_str)
-#     # pattern = r'{s*"table"s*:s*[[^]]*]s*}'
-#     # match = re.search(pattern, text)
-#     # if match:
-#     #     json_string = match.group(0)
-#     #     # 将JSON字符串转换为Python对象
-#     #     json_obj = json.loads(json_string)
-#     # print(summary)
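With the `@staticmethod`s turned into instance methods, callers now construct a client first, as the chat scenes earlier in this commit already do. A hypothetical call sequence; the database name and query are illustrative (the query string is borrowed from the commented-out block removed above):

```python
# Hypothetical usage; "db_test" and the Chinese query are illustrative.
from pilot.summary.db_summary_client import DBSummaryClient

client = DBSummaryClient()
client.db_summary_embedding("db_test")   # embed db profile + per-table summaries
tables = client.get_similar_tables(
    dbname="db_test", query="查询在线用户的购物车", topk=5
)
print(tables)  # summaries of the tables most related to the query
```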
pilot/vector_store/file_loader.py

@@ -17,7 +17,6 @@ from langchain.vectorstores import Chroma
 from pilot.configs.model_config import (
     DATASETS_DIR,
     LLM_MODEL_CONFIG,
-    VECTOR_SEARCH_TOP_K,
     VECTORE_PATH,
 )

@@ -41,7 +40,6 @@ class KnownLedge2Vector:

     embeddings: object = None
     model_name = LLM_MODEL_CONFIG["sentence-transforms"]
-    top_k: int = VECTOR_SEARCH_TOP_K

     def __init__(self, model_name=None) -> None:
         if not model_name:
tools/knowledge_init.py

@@ -10,7 +10,6 @@ from pilot.configs.config import Config
 from pilot.configs.model_config import (
     DATASETS_DIR,
     LLM_MODEL_CONFIG,
-    VECTOR_SEARCH_TOP_K,
 )
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
@@ -19,36 +18,30 @@ CFG = Config()

 class LocalKnowledgeInit:
     embeddings: object = None
-    model_name = LLM_MODEL_CONFIG["text2vec"]
-    top_k: int = VECTOR_SEARCH_TOP_K

     def __init__(self, vector_store_config) -> None:
         self.vector_store_config = vector_store_config
+        self.model_name = LLM_MODEL_CONFIG["text2vec"]

     def knowledge_persist(self, file_path, append_mode):
         """knowledge persist"""
-        kv = KnowledgeEmbedding(
-            file_path=file_path,
-            model_name=LLM_MODEL_CONFIG["text2vec"],
-            vector_store_config=self.vector_store_config,
-        )
-        vector_store = kv.knowledge_persist_initialization(append_mode)
-        return vector_store
-
-    def query(self, q):
-        """Query similar doc from Vector"""
-        vector_store = self.init_vector_store()
-        docs = vector_store.similarity_search_with_score(q, k=self.top_k)
-        for doc in docs:
-            dc, s = doc
-            yield s, dc
+        for root, _, files in os.walk(file_path, topdown=False):
+            for file in files:
+                filename = os.path.join(root, file)
+                # docs = self._load_file(filename)
+                ke = KnowledgeEmbedding(
+                    file_path=filename,
+                    model_name=self.model_name,
+                    vector_store_config=self.vector_store_config,
+                )
+                client = ke.init_knowledge_embedding()
+                client.source_embedding()


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--vector_name", type=str, default="default")
     parser.add_argument("--append", type=bool, default=False)
     parser.add_argument("--store_type", type=str, default="Chroma")
     args = parser.parse_args()
     vector_name = args.vector_name
     append_mode = args.append
@@ -56,5 +49,5 @@ if __name__ == "__main__":
     vector_store_config = {"vector_store_name": vector_name}
     print(vector_store_config)
     kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
-    vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
+    kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
     print("your knowledge embedding success...")