fix: merge (#153)

csunny 2023-06-05 22:29:40 +08:00
commit 4d3079055c
29 changed files with 213 additions and 230 deletions

View File

@@ -28,8 +28,12 @@ MAX_POSITION_EMBEDDINGS=4096
 # FAST_LLM_MODEL=chatglm-6b
-### EMBEDDINGS
-## EMBEDDING_MODEL - Model to use for creating embeddings
+#*******************************************************************#
+#** EMBEDDING SETTINGS **#
+#*******************************************************************#
+EMBEDDING_MODEL=text2vec
+KNOWLEDGE_CHUNK_SIZE=500
+KNOWLEDGE_SEARCH_TOP_SIZE=5
 ## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
 ## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
 # EMBEDDING_MODEL=all-MiniLM-L6-v2

View File

@@ -148,6 +148,8 @@ class Config(metaclass=Singleton):
         ### EMBEDDING Configuration
         self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
+        self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500))
+        self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 10))
         ### SUMMARY_CONFIG Configuration
         self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "VECTOR")
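
A minimal sketch of how the two new settings surface at runtime; the variable names come from the hunk above, while the override values and the standalone script are illustrative:

import os

os.environ["KNOWLEDGE_CHUNK_SIZE"] = "1000"    # optional override
os.environ["KNOWLEDGE_SEARCH_TOP_SIZE"] = "5"  # optional override

from pilot.configs.config import Config

cfg = Config()
# Falls back to the defaults above (500 and 10) when the variables are unset.
print(cfg.KNOWLEDGE_CHUNK_SIZE, cfg.KNOWLEDGE_SEARCH_TOP_SIZE)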

View File

@@ -35,7 +35,6 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
-    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
     "proxyllm": "proxyllm",
 }

View File

View File

@@ -0,0 +1 @@
+LlamaIndex is a data framework designed to help you build LLM applications. It includes a vector store index and a simple directory reader to help you process and work with data. In addition, LlamaIndex provides a GPT Index that can be used for data augmentation and for producing better LM models.

View File

@@ -82,7 +82,7 @@ class ChatGLMAdapater(BaseLLMAdaper):
         )
         return model, tokenizer

 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
@@ -97,7 +97,6 @@ class GuanacoAdapter(BaseLLMAdaper):
         return model, tokenizer

 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""

View File

@@ -47,12 +47,13 @@ class ChatWithDbAutoExecute(BaseChat):
             from pilot.summary.db_summary_client import DBSummaryClient
         except ImportError:
             raise ValueError("Could not import DBSummaryClient. ")
+        client = DBSummaryClient()
         input_values = {
             "input": self.current_user_input,
             "top_k": str(self.top_k),
             "dialect": self.database.dialect,
             "table_info": self.database.table_simple_info(self.db_connect)
-            # "table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
+            # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
         }
         return input_values

View File

@@ -45,7 +45,8 @@ class ChatWithDbQA(BaseChat):
         except ImportError:
             raise ValueError("Could not import DBSummaryClient. ")
         if self.db_name:
-            table_info = DBSummaryClient.get_similar_tables(
+            client = DBSummaryClient()
+            table_info = client.get_similar_tables(
                 dbname=self.db_name, query=self.current_user_input, topk=self.top_k
             )
             # table_info = self.database.table_simple_info(self.db_connect)

View File

@@ -14,7 +14,6 @@ from pilot.configs.model_config import (
     KNOWLEDGE_UPLOAD_ROOT_PATH,
     LLM_MODEL_CONFIG,
     LOGDIR,
-    VECTOR_SEARCH_TOP_K,
 )
 from pilot.scene.chat_knowledge.custom.prompt import prompt
@@ -46,15 +45,13 @@ class ChatNewKnowledge(BaseChat):
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
             model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
             vector_store_config=vector_store_config,
         )

     def generate_input_values(self):
         docs = self.knowledge_embedding_client.similar_search(
-            self.current_user_input, VECTOR_SEARCH_TOP_K
+            self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
         )
         context = [d.page_content for d in docs]
         context = context[:2000]

View File

@@ -14,13 +14,23 @@ CFG = Config()
 PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""

-_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造
 已知内容:
 {context}
 问题:
 {question}
 """
+
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
+known information:
+{context}
+question:
+{question}
+"""
+
+_DEFAULT_TEMPLATE = (
+    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)

 PROMPT_SEP = SeparatorStyle.SINGLE.value
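
Since _DEFAULT_TEMPLATE is resolved at import time, the language has to be set before the module loads. A sketch, assuming CFG.LANGUAGE is driven by the LANGUAGE environment variable:

import os

os.environ["LANGUAGE"] = "en"  # assumed to back CFG.LANGUAGE

from pilot.scene.chat_knowledge.custom.prompt import _DEFAULT_TEMPLATE

# With LANGUAGE=en this prints the English template defined above.
print(_DEFAULT_TEMPLATE.format(context="<retrieved chunks>", question="<user question>"))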

View File

@@ -14,7 +14,6 @@ from pilot.configs.model_config import (
     KNOWLEDGE_UPLOAD_ROOT_PATH,
     LLM_MODEL_CONFIG,
     LOGDIR,
-    VECTOR_SEARCH_TOP_K,
 )
 from pilot.scene.chat_knowledge.default.prompt import prompt
@@ -42,15 +41,13 @@ class ChatDefaultKnowledge(BaseChat):
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
             model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
             vector_store_config=vector_store_config,
         )

     def generate_input_values(self):
         docs = self.knowledge_embedding_client.similar_search(
-            self.current_user_input, VECTOR_SEARCH_TOP_K
+            self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
         )
         context = [d.page_content for d in docs]
         context = context[:2000]

View File

@@ -15,13 +15,23 @@ PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelli
 The assistant gives helpful, detailed, professional and polite answers to the user's questions. """

-_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造
 已知内容:
 {context}
 问题:
 {question}
 """
+
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
+known information:
+{context}
+question:
+{question}
+"""
+
+_DEFAULT_TEMPLATE = (
+    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)

 PROMPT_SEP = SeparatorStyle.SINGLE.value

View File

@@ -14,7 +14,6 @@ from pilot.configs.model_config import (
     KNOWLEDGE_UPLOAD_ROOT_PATH,
     LLM_MODEL_CONFIG,
     LOGDIR,
-    VECTOR_SEARCH_TOP_K,
 )
 from pilot.scene.chat_knowledge.url.prompt import prompt
@@ -40,15 +39,13 @@ class ChatUrlKnowledge(BaseChat):
         self.url = url
         vector_store_config = {
             "vector_store_name": url,
-            "text_field": "content",
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
-            file_path=url,
-            file_type="url",
             model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
             vector_store_config=vector_store_config,
+            file_type="url",
+            file_path=url,
         )

         # url source in vector
@@ -58,7 +55,7 @@ class ChatUrlKnowledge(BaseChat):
     def generate_input_values(self):
         docs = self.knowledge_embedding_client.similar_search(
-            self.current_user_input, VECTOR_SEARCH_TOP_K
+            self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
         )
         context = [d.page_content for d in docs]
         context = context[:2000]

View File

@@ -14,20 +14,23 @@ CFG = Config()
 PROMPT_SCENE_DEFINE = """A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge.
 The assistant gives helpful, detailed, professional and polite answers to the user's questions. """

-# _DEFAULT_TEMPLATE = """ Based on the known information, provide professional and concise answers to the user's questions. If the answer cannot be obtained from the provided content, please say: 'The information provided in the knowledge base is not sufficient to answer this question.' Fabrication is prohibited.
-# known information:
-# {context}
-# question:
-# {question}
-# """
-_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造
 已知内容:
 {context}
 问题:
 {question}
 """
+
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
+known information:
+{context}
+question:
+{question}
+"""
+
+_DEFAULT_TEMPLATE = (
+    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)

 PROMPT_SEP = SeparatorStyle.SINGLE.value

View File

@@ -59,6 +59,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
         return chatglm_generate_stream

 class CodeT5ChatAdapter(BaseChatAdpter):
     """Model chat adapter for CodeT5"""

View File

@@ -3,12 +3,14 @@
 from langchain.prompts import PromptTemplate

-from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
+from pilot.configs.config import Config
 from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates
 from pilot.logs import logger
 from pilot.model.llm_out.vicuna_llm import VicunaLLM
 from pilot.vector_store.file_loader import KnownLedge2Vector

+CFG = Config()
+
 class KnownLedgeBaseQA:
     def __init__(self) -> None:
@@ -22,7 +24,7 @@ class KnownLedgeBaseQA:
         )
         retriever = self.vector_store.as_retriever(
-            search_kwargs={"k": VECTOR_SEARCH_TOP_K}
+            search_kwargs={"k": CFG.KNOWLEDGE_SEARCH_TOP_SIZE}
         )
         docs = retriever.get_relevant_documents(query=query)

View File

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
+import threading
 import traceback
 import argparse
 import datetime
@@ -415,7 +416,7 @@ def build_single_model_ui():
                 show_label=True,
             ).style(container=False)
-            db_selector.change(fn=db_selector_changed, inputs=db_selector)
+            # db_selector.change(fn=db_selector_changed, inputs=db_selector)

         sql_mode = gr.Radio(
             [
@@ -619,10 +620,6 @@ def save_vs_name(vs_name):
     return vs_name

-def db_selector_changed(dbname):
-    DBSummaryClient.db_summary_embedding(dbname)
-
 def knowledge_embedding_store(vs_id, files):
     # vs_path = os.path.join(VS_ROOT_PATH, vs_id)
     if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
@@ -635,7 +632,6 @@ def knowledge_embedding_store(vs_id, files):
         knowledge_embedding_client = KnowledgeEmbedding(
             file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
             model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
             vector_store_config={
                 "vector_store_name": vector_store_name["vs_name"],
                 "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -647,6 +643,12 @@ def knowledge_embedding_store(vs_id, files):
     return vs_id

+def async_db_summery():
+    client = DBSummaryClient()
+    thread = threading.Thread(target=client.init_db_summary)
+    thread.start()
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -663,7 +665,7 @@ if __name__ == "__main__":
     cfg = Config()
     dbs = cfg.local_db.get_database_list()
+    async_db_summery()
     cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
     # Load plugin executable commands

View File

@@ -0,0 +1,26 @@
+from typing import List, Optional
+
+import chardet
+from langchain.docstore.document import Document
+from langchain.document_loaders.base import BaseLoader
+
+
+class EncodeTextLoader(BaseLoader):
+    """Load text files."""
+
+    def __init__(self, file_path: str, encoding: Optional[str] = None):
+        """Initialize with file path."""
+        self.file_path = file_path
+        self.encoding = encoding
+
+    def load(self) -> List[Document]:
+        """Load from file path."""
+        with open(self.file_path, "rb") as f:
+            raw_text = f.read()
+            result = chardet.detect(raw_text)
+            if result["encoding"] is None:
+                text = raw_text.decode("utf-8")
+            else:
+                text = raw_text.decode(result["encoding"])
+        metadata = {"source": self.file_path}
+        return [Document(page_content=text, metadata=metadata)]
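
A short usage sketch for the new loader; the file path is illustrative. Because chardet sniffs the raw bytes, GBK- and UTF-8-encoded files load through the same call:

from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader

loader = EncodeTextLoader("./datasets/example.md")  # hypothetical path
docs = loader.load()
# One Document per file, with the source path recorded in metadata.
print(docs[0].metadata["source"], len(docs[0].page_content))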

View File

@@ -12,14 +12,12 @@ class CSVEmbedding(SourceEmbedding):
     def __init__(
         self,
         file_path,
-        model_name,
         vector_store_config,
         embedding_args: Optional[Dict] = None,
     ):
         """Initialize with csv path."""
-        super().__init__(file_path, model_name, vector_store_config)
+        super().__init__(file_path, vector_store_config)
         self.file_path = file_path
-        self.model_name = model_name
         self.vector_store_config = vector_store_config
         self.embedding_args = embedding_args

View File

@@ -1,30 +1,34 @@
-import os
-import markdown
+from typing import Optional

-from bs4 import BeautifulSoup
-from langchain.document_loaders import PyPDFLoader, TextLoader
 from langchain.embeddings import HuggingFaceEmbeddings

 from pilot.configs.config import Config
-from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
-from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
 from pilot.source_embedding.csv_embedding import CSVEmbedding
 from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
 from pilot.source_embedding.pdf_embedding import PDFEmbedding
 from pilot.source_embedding.url_embedding import URLEmbedding
+from pilot.source_embedding.word_embedding import WordEmbedding
 from pilot.vector_store.connector import VectorStoreConnector

 CFG = Config()

+KnowledgeEmbeddingType = {
+    ".txt": (MarkdownEmbedding, {}),
+    ".md": (MarkdownEmbedding, {}),
+    ".pdf": (PDFEmbedding, {}),
+    ".doc": (WordEmbedding, {}),
+    ".docx": (WordEmbedding, {}),
+    ".csv": (CSVEmbedding, {}),
+}
+
 class KnowledgeEmbedding:
     def __init__(
         self,
-        file_path,
         model_name,
         vector_store_config,
-        local_persist=True,
-        file_type="default",
+        file_type: Optional[str] = "default",
+        file_path: Optional[str] = None,
     ):
         """Initialize with Loader url, model_name, vector_store_config"""
         self.file_path = file_path
@@ -33,11 +37,9 @@ class KnowledgeEmbedding:
         self.file_type = file_type
         self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
         self.vector_store_config["embeddings"] = self.embeddings
-        self.local_persist = local_persist
-        if not self.local_persist:
-            self.knowledge_embedding_client = self.init_knowledge_embedding()

     def knowledge_embedding(self):
+        self.knowledge_embedding_client = self.init_knowledge_embedding()
         self.knowledge_embedding_client.source_embedding()

     def knowledge_embedding_batch(self):
@@ -50,95 +52,27 @@
                 model_name=self.model_name,
                 vector_store_config=self.vector_store_config,
             )
-        elif self.file_path.endswith(".pdf"):
-            embedding = PDFEmbedding(
-                file_path=self.file_path,
-                model_name=self.model_name,
-                vector_store_config=self.vector_store_config,
-            )
-        elif self.file_path.endswith(".md"):
-            embedding = MarkdownEmbedding(
-                file_path=self.file_path,
-                model_name=self.model_name,
-                vector_store_config=self.vector_store_config,
-            )
-        elif self.file_path.endswith(".csv"):
-            embedding = CSVEmbedding(
-                file_path=self.file_path,
-                model_name=self.model_name,
-                vector_store_config=self.vector_store_config,
-            )
-        elif self.file_type == "default":
-            embedding = MarkdownEmbedding(
-                file_path=self.file_path,
-                model_name=self.model_name,
-                vector_store_config=self.vector_store_config,
-            )
+            return embedding
+        extension = "." + self.file_path.rsplit(".", 1)[-1]
+        if extension in KnowledgeEmbeddingType:
+            knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
+            embedding = knowledge_class(
+                self.file_path,
+                vector_store_config=self.vector_store_config,
+                **knowledge_args,
+            )
+            return embedding
+        raise ValueError(f"Unsupported knowledge file type '{extension}'")
         return embedding

     def similar_search(self, text, topk):
-        return self.knowledge_embedding_client.similar_search(text, topk)
-
-    def vector_exist(self):
-        return self.knowledge_embedding_client.vector_name_exist()
-
-    def knowledge_persist_initialization(self, append_mode):
-        documents = self._load_knownlege(self.file_path)
-        self.vector_client = VectorStoreConnector(
+        vector_client = VectorStoreConnector(
             CFG.VECTOR_STORE_TYPE, self.vector_store_config
         )
-        self.vector_client.load_document(documents)
-        return self.vector_client
+        return vector_client.similar_search(text, topk)

-    def _load_knownlege(self, path):
-        docments = []
-        for root, _, files in os.walk(path, topdown=False):
-            for file in files:
-                filename = os.path.join(root, file)
-                docs = self._load_file(filename)
-                new_docs = []
-                for doc in docs:
-                    doc.metadata = {
-                        "source": doc.metadata["source"].replace(DATASETS_DIR, "")
-                    }
-                    print("doc is embedding...", doc.metadata)
-                    new_docs.append(doc)
-                docments += new_docs
-        return docments
-
-    def _load_file(self, filename):
-        if filename.lower().endswith(".md"):
-            loader = TextLoader(filename)
-            text_splitter = CHNDocumentSplitter(
-                pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
-            )
-            docs = loader.load_and_split(text_splitter)
-            i = 0
-            for d in docs:
-                content = markdown.markdown(d.page_content)
-                soup = BeautifulSoup(content, "html.parser")
-                for tag in soup(["!doctype", "meta", "i.fa"]):
-                    tag.extract()
-                docs[i].page_content = soup.get_text()
-                docs[i].page_content = docs[i].page_content.replace("\n", " ")
-                i += 1
-        elif filename.lower().endswith(".pdf"):
-            loader = PyPDFLoader(filename)
-            textsplitter = CHNDocumentSplitter(
-                pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
-            )
-            docs = loader.load_and_split(textsplitter)
-            i = 0
-            for d in docs:
-                docs[i].page_content = d.page_content.replace("\n", " ").replace(
-                    "\ufffd", ""
-                )
-                i += 1
-        else:
-            loader = TextLoader(filename)
-            text_splitor = CHNDocumentSplitter(sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE)
-            docs = loader.load_and_split(text_splitor)
-        return docs
+    def vector_exist(self):
+        vector_client = VectorStoreConnector(
+            CFG.VECTOR_STORE_TYPE, self.vector_store_config
+        )
+        return vector_client.vector_name_exists()
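
With the refactor above, callers no longer chain endswith checks; the KnowledgeEmbeddingType map picks the embedding class from the file extension. A hedged usage sketch (the store name, store path, and file path are illustrative):

from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding

client = KnowledgeEmbedding(
    model_name=LLM_MODEL_CONFIG["text2vec"],
    vector_store_config={
        "vector_store_name": "my_docs",          # illustrative
        "vector_store_path": "/path/to/store",   # illustrative
    },
    file_path="./datasets/guide.md",
)
client.knowledge_embedding()  # ".md" dispatches to MarkdownEmbedding via the map
docs = client.similar_search("how do I configure the model?", 5)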

View File

@@ -8,27 +8,30 @@ from bs4 import BeautifulSoup
 from langchain.document_loaders import TextLoader
 from langchain.schema import Document

-from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
+from pilot.configs.config import Config
 from pilot.source_embedding import SourceEmbedding, register
+from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader
 from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

+CFG = Config()
+
 class MarkdownEmbedding(SourceEmbedding):
     """markdown embedding for read markdown document."""

-    def __init__(self, file_path, model_name, vector_store_config):
+    def __init__(self, file_path, vector_store_config):
         """Initialize with markdown path."""
-        super().__init__(file_path, model_name, vector_store_config)
+        super().__init__(file_path, vector_store_config)
         self.file_path = file_path
-        self.model_name = model_name
         self.vector_store_config = vector_store_config
+        # self.encoding = encoding

     @register
     def read(self):
         """Load from markdown path."""
-        loader = TextLoader(self.file_path)
+        loader = EncodeTextLoader(self.file_path)
         text_splitter = CHNDocumentSplitter(
-            pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
+            pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
         )
         return loader.load_and_split(text_splitter)

View File

@@ -5,20 +5,22 @@ from typing import List
 from langchain.document_loaders import PyPDFLoader
 from langchain.schema import Document

-from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
+from pilot.configs.config import Config
 from pilot.source_embedding import SourceEmbedding, register
 from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

+CFG = Config()
+
 class PDFEmbedding(SourceEmbedding):
     """pdf embedding for read pdf document."""

-    def __init__(self, file_path, model_name, vector_store_config):
+    def __init__(self, file_path, vector_store_config, encoding):
         """Initialize with pdf path."""
-        super().__init__(file_path, model_name, vector_store_config)
+        super().__init__(file_path, vector_store_config)
         self.file_path = file_path
-        self.model_name = model_name
         self.vector_store_config = vector_store_config
+        self.encoding = encoding

     @register
     def read(self):
@@ -26,7 +28,7 @@ class PDFEmbedding(SourceEmbedding):
         # loader = UnstructuredPaddlePDFLoader(self.file_path)
         loader = PyPDFLoader(self.file_path)
         textsplitter = CHNDocumentSplitter(
-            pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
+            pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
         )
         return loader.load_and_split(textsplitter)

View File

@@ -23,13 +23,11 @@ class SourceEmbedding(ABC):
     def __init__(
         self,
         file_path,
-        model_name,
         vector_store_config,
         embedding_args: Optional[Dict] = None,
     ):
         """Initialize with Loader url, model_name, vector_store_config"""
         self.file_path = file_path
-        self.model_name = model_name
         self.vector_store_config = vector_store_config
         self.embedding_args = embedding_args
         self.embeddings = vector_store_config["embeddings"]

View File

@@ -8,11 +8,10 @@ from pilot import SourceEmbedding, register
 class StringEmbedding(SourceEmbedding):
     """string embedding for read string document."""

-    def __init__(self, file_path, model_name, vector_store_config):
+    def __init__(self, file_path, vector_store_config):
         """Initialize with pdf path."""
-        super().__init__(file_path, model_name, vector_store_config)
+        super().__init__(file_path, vector_store_config)
         self.file_path = file_path
-        self.model_name = model_name
         self.vector_store_config = vector_store_config

     @register

View File

@@ -16,11 +16,10 @@ CFG = Config()
 class URLEmbedding(SourceEmbedding):
     """url embedding for read url document."""

-    def __init__(self, file_path, model_name, vector_store_config):
+    def __init__(self, file_path, vector_store_config):
         """Initialize with url path."""
-        super().__init__(file_path, model_name, vector_store_config)
+        super().__init__(file_path, vector_store_config)
         self.file_path = file_path
-        self.model_name = model_name
         self.vector_store_config = vector_store_config

     @register
@@ -29,7 +28,7 @@ class URLEmbedding(SourceEmbedding):
         loader = WebBaseLoader(web_path=self.file_path)
         if CFG.LANGUAGE == "en":
             text_splitter = CharacterTextSplitter(
-                chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE,
+                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
                 chunk_overlap=20,
                 length_function=len,
             )

View File

@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from typing import List
+
+from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader
+from langchain.schema import Document
+
+from pilot.configs.config import Config
+from pilot.source_embedding import SourceEmbedding, register
+from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
+
+CFG = Config()
+
+
+class WordEmbedding(SourceEmbedding):
+    """word embedding for read word document."""
+
+    def __init__(self, file_path, vector_store_config):
+        """Initialize with word path."""
+        super().__init__(file_path, vector_store_config)
+        self.file_path = file_path
+        self.vector_store_config = vector_store_config
+
+    @register
+    def read(self):
+        """Load from word path."""
+        loader = UnstructuredWordDocumentLoader(self.file_path)
+        textsplitter = CHNDocumentSplitter(
+            pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
+        )
+        return loader.load_and_split(textsplitter)
+
+    @register
+    def data_process(self, documents: List[Document]):
+        i = 0
+        for d in documents:
+            documents[i].page_content = d.page_content.replace("\n", "")
+            i += 1
+        return documents
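
For reference, a sketch of driving the new source directly; in normal use it is reached through the KnowledgeEmbeddingType map for .doc/.docx files. The paths are illustrative, and the embeddings object is required because SourceEmbedding reads it from vector_store_config:

from langchain.embeddings import HuggingFaceEmbeddings
from pilot.source_embedding.word_embedding import WordEmbedding

embeddings = HuggingFaceEmbeddings(model_name="/path/to/text2vec")  # illustrative
emb = WordEmbedding(
    file_path="./datasets/design.docx",  # illustrative
    vector_store_config={"vector_store_name": "design_docs", "embeddings": embeddings},
)
chunks = emb.read()                # CHNDocumentSplitter-based chunking
chunks = emb.data_process(chunks)  # strips newlines from each chunk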

View File

@@ -21,8 +21,10 @@ class DBSummaryClient:
     , get_similar_tables method(get user query related tables info)
     """

-    @staticmethod
-    def db_summary_embedding(dbname):
+    def __init__(self):
+        pass
+
+    def db_summary_embedding(self, dbname):
         """put db profile and table profile summary into vector store"""
         if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None:
             db_summary_client = MysqlSummary(dbname)
@@ -34,24 +36,21 @@ class DBSummaryClient:
             "embeddings": embeddings,
         }
         embedding = StringEmbedding(
-            db_summary_client.get_summery(),
-            LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
-            vector_store_config,
+            file_path=db_summary_client.get_summery(),
+            vector_store_config=vector_store_config,
         )
         if not embedding.vector_name_exist():
             if CFG.SUMMARY_CONFIG == "FAST":
                 for vector_table_info in db_summary_client.get_summery():
                     embedding = StringEmbedding(
                         vector_table_info,
-                        LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                         vector_store_config,
                     )
                     embedding.source_embedding()
             else:
                 embedding = StringEmbedding(
-                    db_summary_client.get_summery(),
-                    LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
-                    vector_store_config,
+                    file_path=db_summary_client.get_summery(),
+                    vector_store_config=vector_store_config,
                 )
                 embedding.source_embedding()
         for (
@@ -59,32 +58,24 @@ class DBSummaryClient:
             table_summary,
         ) in db_summary_client.get_table_summary().items():
             table_vector_store_config = {
-                "vector_store_name": table_name + "_ts",
+                "vector_store_name": dbname + "_" + table_name + "_ts",
                 "embeddings": embeddings,
             }
             embedding = StringEmbedding(
                 table_summary,
-                LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                 table_vector_store_config,
             )
             embedding.source_embedding()
         logger.info("db summary embedding success")

-    @staticmethod
-    def get_similar_tables(dbname, query, topk):
+    def get_similar_tables(self, dbname, query, topk):
         """get user query related tables info"""
-        embeddings = HuggingFaceEmbeddings(
-            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
-        )
         vector_store_config = {
             "vector_store_name": dbname + "_profile",
-            "embeddings": embeddings,
         }
         knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
             model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
-            local_persist=False,
             vector_store_config=vector_store_config,
         )
         if CFG.SUMMARY_CONFIG == "FAST":
@@ -104,19 +95,23 @@ class DBSummaryClient:
         related_table_summaries = []
         for table in related_tables:
             vector_store_config = {
-                "vector_store_name": table + "_ts",
-                "embeddings": embeddings,
+                "vector_store_name": dbname + "_" + table + "_ts",
             }
             knowledge_embedding_client = KnowledgeEmbedding(
                 file_path="",
                 model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
-                local_persist=False,
                 vector_store_config=vector_store_config,
             )
             table_summery = knowledge_embedding_client.similar_search(query, 1)
             related_table_summaries.append(table_summery[0].page_content)
         return related_table_summaries

+    def init_db_summary(self):
+        db = CFG.local_db
+        dbs = db.get_database_list()
+        for dbname in dbs:
+            self.db_summary_embedding(dbname)
+
 def _get_llm_response(query, db_input, dbsummary):
     chat_param = {
@@ -132,30 +127,3 @@ def _get_llm_response(query, db_input, dbsummary):
     )
     res = chat.nostream_call()
     return json.loads(res)["table"]
-
-# if __name__ == "__main__":
-#     # summary = DBSummaryClient.get_similar_tables("db_test", "查询在线用户的购物车", 10)
-#
-#     text= """Based on the input "查询在线聊天的用户好友" and the known database information, the tables involved in the user input are "chat_users" and "friends".
-#     Response:
-#
-#     {
-#         "table": ["chat_users"]
-#     }"""
-#     text = text.rstrip().replace("\n", "")
-#     start = text.find("{")
-#     end = text.find("}") + 1
-#
-#     # Extract the JSON data from the string
-#     json_str = text[start:end]
-#
-#     # Convert the JSON data to a Python dict
-#     data = json.loads(json_str)
-#     # pattern = r'\{\s*"table"\s*:\s*\[[^\]]*\]\s*\}'
-#     # match = re.search(pattern, text)
-#     # if match:
-#     #     json_string = match.group(0)
-#     #     # Convert the JSON string to a Python object
-#     #     json_obj = json.loads(json_string)
-#     # print(summary)
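
A sketch of the instance-based API after this change; the database name and query are illustrative:

from pilot.summary.db_summary_client import DBSummaryClient

client = DBSummaryClient()
client.init_db_summary()  # embeds a profile for every database in CFG.local_db

tables = client.get_similar_tables(
    dbname="db_test", query="find online users' shopping carts", topk=5
)
print(tables)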

View File

@@ -17,7 +17,6 @@ from langchain.vectorstores import Chroma
 from pilot.configs.model_config import (
     DATASETS_DIR,
     LLM_MODEL_CONFIG,
-    VECTOR_SEARCH_TOP_K,
     VECTORE_PATH,
 )
@@ -41,7 +40,6 @@ class KnownLedge2Vector:
     embeddings: object = None
     model_name = LLM_MODEL_CONFIG["sentence-transforms"]
-    top_k: int = VECTOR_SEARCH_TOP_K

     def __init__(self, model_name=None) -> None:
         if not model_name:

View File

@@ -10,7 +10,6 @@ from pilot.configs.config import Config
 from pilot.configs.model_config import (
     DATASETS_DIR,
     LLM_MODEL_CONFIG,
-    VECTOR_SEARCH_TOP_K,
 )
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
@@ -19,36 +18,30 @@ CFG = Config()
 class LocalKnowledgeInit:
     embeddings: object = None
-    model_name = LLM_MODEL_CONFIG["text2vec"]
-    top_k: int = VECTOR_SEARCH_TOP_K

     def __init__(self, vector_store_config) -> None:
         self.vector_store_config = vector_store_config
+        self.model_name = LLM_MODEL_CONFIG["text2vec"]

     def knowledge_persist(self, file_path, append_mode):
         """knowledge persist"""
-        kv = KnowledgeEmbedding(
-            file_path=file_path,
-            model_name=LLM_MODEL_CONFIG["text2vec"],
-            vector_store_config=self.vector_store_config,
-        )
-        vector_store = kv.knowledge_persist_initialization(append_mode)
-        return vector_store
-
-    def query(self, q):
-        """Query similar doc from Vector"""
-        vector_store = self.init_vector_store()
-        docs = vector_store.similarity_search_with_score(q, k=self.top_k)
-        for doc in docs:
-            dc, s = doc
-            yield s, dc
+        for root, _, files in os.walk(file_path, topdown=False):
+            for file in files:
+                filename = os.path.join(root, file)
+                # docs = self._load_file(filename)
+                ke = KnowledgeEmbedding(
+                    file_path=filename,
+                    model_name=self.model_name,
+                    vector_store_config=self.vector_store_config,
+                )
+                client = ke.init_knowledge_embedding()
+                client.source_embedding()

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--vector_name", type=str, default="default")
     parser.add_argument("--append", type=bool, default=False)
+    parser.add_argument("--store_type", type=str, default="Chroma")
     args = parser.parse_args()
     vector_name = args.vector_name
     append_mode = args.append
@@ -56,5 +49,5 @@ if __name__ == "__main__":
     vector_store_config = {"vector_store_name": vector_name}
     print(vector_store_config)
     kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
-    vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
+    kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
     print("your knowledge embedding success...")