langchain qa base

This commit is contained in:
csunny 2023-05-07 14:47:15 +08:00
parent 63586fc6a3
commit 3c40659527
5 changed files with 66 additions and 8 deletions

4
.gitignore vendored
View File

@ -130,4 +130,6 @@ dmypy.json
# Pyre type checker
.pyre/
.DS_Store
logs
logs
.vectordb

View File

@ -22,6 +22,7 @@ LLM_MODEL_CONFIG = {
}
VECTOR_SEARCH_TOP_K = 5
LLM_MODEL = "vicuna-13b"
LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 2048

View File

@ -146,6 +146,16 @@ conv_vicuna_v1 = Conversation(
sep2="</s>",
)
# Prompt template for knowledge-base QA: instructs the model to answer
# professionally and in detail strictly from the retrieved context, and —
# when the context is insufficient — to say so explicitly while still
# offering related suggestions. Placeholders: {context}, {question}.
# BUG FIX: removed a stray typo character (恶) from "提供的内容".
conv_qk_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
已知内容:
{context}
问题:
{question}
"""
default_conversation = conv_one_shot
conv_templates = {

View File

@ -0,0 +1,33 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pilot.vector_store.file_loader import KnownLedge2Vector
from langchain.prompts import PromptTemplate
from pilot.conversation import conv_qk_prompt_template
from langchain.chains import RetrievalQA
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
class KnownLedgeBaseQA:
    """Question answering over the local knowledge base via a RetrievalQA chain."""

    # LLM instance; expected to be injected by the caller before get_answer()
    # is used (it is None by default and never set in __init__).
    llm: object = None

    def __init__(self) -> None:
        # Build (or load) the vector store through the knowledge-base loader.
        k2v = KnownLedge2Vector()
        self.vector_store = k2v.init_vector_store()

    def get_answer(self, query):
        """Yield the RetrievalQA result (with source documents) for *query*.

        Args:
            query: the user question, substituted into the QA prompt.

        Yields:
            The chain's result mapping, including retrieved source documents.
        """
        prompt = PromptTemplate(
            template=conv_qk_prompt_template,
            input_variables=["context", "question"],
        )
        knownledge_chain = RetrievalQA.from_llm(
            llm=self.llm,
            # BUG FIX: search_kwargs must be a dict; the original passed the
            # SET literal {"k", VECTOR_SEARCH_TOP_K}, so the top-k limit was
            # never actually forwarded to the retriever.
            retriever=self.vector_store.as_retriever(
                search_kwargs={"k": VECTOR_SEARCH_TOP_K}
            ),
            prompt=prompt,
        )
        # Include the retrieved source documents in the chain output so the
        # caller can show provenance for the answer.
        knownledge_chain.return_source_documents = True
        result = knownledge_chain({"query": query})
        yield result

View File

@ -10,9 +10,8 @@ from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader, TextLoader
from langchain.chains import VectorDBQA
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG
from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
VECTOR_SEARCH_TOP_K = 5
class KnownLedge2Vector:
@ -26,21 +25,21 @@ class KnownLedge2Vector:
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
def init_vector_store(self):
documents = self.load_knownlege()
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
print("向量数据库持久化地址: ", persist_dir)
if os.path.exists(persist_dir):
# 从本地持久化文件中Load
print("从本地向量加载数据...")
vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
vector_store.add_documents(documents=documents)
# vector_store.add_documents(documents=documents)
else:
documents = self.load_knownlege()
# 重新初始化
vector_store = Chroma.from_documents(documents=documents,
embedding=self.embeddings,
persist_directory=persist_dir)
vector_store.persist()
vector_store = None
return persist_dir
return vector_store
def load_knownlege(self):
docments = []
@ -70,10 +69,23 @@ class KnownLedge2Vector:
return docs
def _load_from_url(self, url):
    """Load data from a URL address.

    Not implemented yet — this is a placeholder stub.
    """
    # TODO: implement remote document loading.
    pass
def query(self, q):
    """Yield (score, document) pairs for the docs most similar to *q*.

    Args:
        q: the query text to match against the vector store.

    Yields:
        Tuples of (similarity score, matched document).
    """
    store = self.init_vector_store()
    scored_docs = store.similarity_search_with_score(q, k=self.top_k)
    # similarity_search_with_score returns (document, score) pairs;
    # unpack directly and emit them score-first, as the original did.
    for document, score in scored_docs:
        yield score, document
if __name__ == "__main__":
    # Smoke test: build the vector store, then run one similarity query.
    k2v = KnownLedge2Vector()
    # NOTE(review): query() below calls init_vector_store() itself, so this
    # explicit call looks redundant — this span comes from a flattened diff;
    # confirm which lines belong to the final version.
    k2v.init_vector_store()
    persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
    print(persist_dir)
    for s, dc in k2v.query("什么是OceanBase"):
        # s: similarity score; dc: matched document (page_content + metadata)
        print(s, dc.page_content, dc.metadata)