diff --git a/.gitignore b/.gitignore
index 4b46e7091..07be74f26 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,4 +130,6 @@ dmypy.json
 # Pyre type checker
 .pyre/
 .DS_Store
-logs
\ No newline at end of file
+logs
+
+.vectordb
\ No newline at end of file
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index ca25b3224..149ddb296 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -22,6 +22,7 @@ LLM_MODEL_CONFIG = {
 }
 
 
+VECTOR_SEARCH_TOP_K = 5
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
diff --git a/pilot/conversation.py b/pilot/conversation.py
index e88ceaccb..688a5c70d 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -146,6 +146,16 @@ conv_vicuna_v1 = Conversation(
     sep2="",
 )
 
+
+conv_qk_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
+            如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
+
+            已知内容:
+            {context}
+            问题:
+            {question}
+"""
+
 default_conversation = conv_one_shot
 
 conv_templates = {
diff --git a/pilot/server/vectordb_qa.py b/pilot/server/vectordb_qa.py
new file mode 100644
index 000000000..083ce20cd
--- /dev/null
+++ b/pilot/server/vectordb_qa.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from pilot.vector_store.file_loader import KnownLedge2Vector
+from langchain.prompts import PromptTemplate
+from pilot.conversation import conv_qk_prompt_template
+from langchain.chains import RetrievalQA
+from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
+
+
+class KnownLedgeBaseQA:
+    """Question answering over the knowledge-base vector store.
+
+    ``llm`` defaults to ``None`` and is passed to the chain as-is, so it is
+    presumably assigned by the caller before ``get_answer`` is used.
+    """
+
+    llm: object = None
+
+    def __init__(self) -> None:
+        k2v = KnownLedge2Vector()
+        self.vector_store = k2v.init_vector_store()
+
+    def get_answer(self, query):
+        """Run a RetrievalQA chain for ``query`` and yield its result once."""
+        prompt = PromptTemplate(
+            template=conv_qk_prompt_template,
+            input_variables=["context", "question"],
+        )
+
+        knownledge_chain = RetrievalQA.from_llm(
+            llm=self.llm,
+            # search_kwargs must be a dict ("k": top-k), not a set literal.
+            retriever=self.vector_store.as_retriever(
+                search_kwargs={"k": VECTOR_SEARCH_TOP_K}
+            ),
+            prompt=prompt,
+        )
+        knownledge_chain.return_source_documents = True
+        result = knownledge_chain({"query": query})
+        yield result
diff --git a/pilot/vector_store/file_loader.py b/pilot/vector_store/file_loader.py
index 269ec05f1..881c8106f 100644
--- a/pilot/vector_store/file_loader.py
+++ b/pilot/vector_store/file_loader.py
@@ -10,9 +10,8 @@ from langchain.text_splitter import CharacterTextSplitter
 from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader, TextLoader
 from langchain.chains import VectorDBQA
 from langchain.embeddings import HuggingFaceEmbeddings
-from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG
+from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
 
-VECTOR_SEARCH_TOP_K = 5
 
 class KnownLedge2Vector:
 
@@ -26,21 +25,21 @@ class KnownLedge2Vector:
         self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
 
     def init_vector_store(self):
-        documents = self.load_knownlege()
         persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
         print("向量数据库持久化地址: ", persist_dir)
         if os.path.exists(persist_dir):
             # 从本地持久化文件中Load
+            print("从本地向量加载数据...")
             vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
-            vector_store.add_documents(documents=documents)
+            # vector_store.add_documents(documents=documents)
         else:
+            documents = self.load_knownlege()
             # 重新初始化
             vector_store = Chroma.from_documents(documents=documents,
                                                  embedding=self.embeddings,
                                                  persist_directory=persist_dir)
             vector_store.persist()
-        vector_store = None
-        return persist_dir
+        return vector_store
 
     def load_knownlege(self):
         docments = []
@@ -70,10 +69,23 @@ class KnownLedge2Vector:
         return docs
 
     def _load_from_url(self, url):
+        """Load data from url address"""
         pass
 
+
+    def query(self, q):
+        """Yield (score, document) pairs for the top_k documents most similar to q."""
+        vector_store = self.init_vector_store()
+        docs = vector_store.similarity_search_with_score(q, k=self.top_k)
+        for doc in docs:
+            dc, s = doc
+            yield s, dc
 
 
 if __name__ == "__main__":
     k2v = KnownLedge2Vector()
-    k2v.init_vector_store()
+
+    persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
+    print(persist_dir)
+    for s, dc in k2v.query("什么是OceanBase"):
+        print(s, dc.page_content, dc.metadata)
\ No newline at end of file