mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-29 05:10:56 +00:00)

commit 3c40659527 (parent 63586fc6a3)

    langchain qa base
.gitignore (vendored, 2 lines changed)

@@ -131,3 +131,5 @@ dmypy.json
 .pyre/
 .DS_Store
 logs
+
+.vectordb
pilot/configs/model_config.py

@@ -22,6 +22,7 @@ LLM_MODEL_CONFIG = {
 }
 
 
+VECTOR_SEARCH_TOP_K = 5
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
pilot/conversation.py

@@ -146,6 +146,16 @@ conv_vicuna_v1 = Conversation(
     sep2="</s>",
 )
 
+
+conv_qk_prompt_template = """Based on the known information below, answer the user's question professionally and in detail.
+If the answer cannot be obtained from the provided content, say: "The content provided in the knowledge base is not sufficient to answer this question", but you may offer some suggestions related to the question:
+
+Known content:
+{context}
+Question:
+{question}
+"""
+
 default_conversation = conv_one_shot
 
 conv_templates = {
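For illustration only, a minimal sketch of how a two-variable template like this is rendered with langchain's PromptTemplate; the sample context and question values below are made up, not from the commit:

    from langchain.prompts import PromptTemplate
    from pilot.conversation import conv_qk_prompt_template

    # The template declares two placeholders, {context} and {question}.
    prompt = PromptTemplate(
        template=conv_qk_prompt_template,
        input_variables=["context", "question"],
    )

    # Hypothetical inputs; in DB-GPT the context is filled from vector search hits.
    text = prompt.format(
        context="OceanBase is a distributed relational database.",
        question="What is OceanBase?",
    )
    print(text)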
pilot/server/vectordb_qa.py (new file, 33 lines)

@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from pilot.vector_store.file_loader import KnownLedge2Vector
+from langchain.prompts import PromptTemplate
+from pilot.conversation import conv_qk_prompt_template
+from langchain.chains import RetrievalQA
+from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
+
+class KnownLedgeBaseQA:
+
+    llm: object = None
+
+    def __init__(self) -> None:
+        k2v = KnownLedge2Vector()
+        self.vector_store = k2v.init_vector_store()
+
+    def get_answer(self, query):
+        prompt_template = conv_qk_prompt_template
+
+        prompt = PromptTemplate(
+            template=prompt_template,
+            input_variables=["context", "question"]
+        )
+
+        knownledge_chain = RetrievalQA.from_llm(
+            llm=self.llm,
+            retriever=self.vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
+            prompt=prompt
+        )
+        knownledge_chain.return_source_documents = True
+        result = knownledge_chain({"query": query})
+        yield result
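Since llm defaults to None, a caller must attach a langchain-compatible model before get_answer can run. A usage sketch under that assumption, using langchain's FakeListLLM as a self-contained stand-in (any real LLM object RetrievalQA accepts would do):

    from langchain.llms.fake import FakeListLLM  # stand-in model for the sketch
    from pilot.server.vectordb_qa import KnownLedgeBaseQA

    qa = KnownLedgeBaseQA()
    qa.llm = FakeListLLM(responses=["OceanBase is a distributed relational database."])

    # get_answer yields RetrievalQA result dicts; with return_source_documents
    # enabled, each dict carries the answer plus the retrieved chunks.
    for result in qa.get_answer("What is OceanBase?"):
        print(result["result"])
        for doc in result["source_documents"]:
            print(doc.metadata)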
pilot/vector_store/file_loader.py

@@ -10,9 +10,8 @@ from langchain.text_splitter import CharacterTextSplitter
 from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader, TextLoader
 from langchain.chains import VectorDBQA
 from langchain.embeddings import HuggingFaceEmbeddings
-from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG
+from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
 
-VECTOR_SEARCH_TOP_K = 5
 
 class KnownLedge2Vector:
@@ -26,21 +25,21 @@ class KnownLedge2Vector:
         self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
 
     def init_vector_store(self):
-        documents = self.load_knownlege()
         persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
         print("Vector store persistence path: ", persist_dir)
         if os.path.exists(persist_dir):
             # Load from the locally persisted files
+            print("Loading data from the local vector store...")
             vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
-            vector_store.add_documents(documents=documents)
+            # vector_store.add_documents(documents=documents)
         else:
+            documents = self.load_knownlege()
             # Re-initialize
             vector_store = Chroma.from_documents(documents=documents,
                                                  embedding=self.embeddings,
                                                  persist_directory=persist_dir)
             vector_store.persist()
-            vector_store = None
-        return persist_dir
+        return vector_store
 
     def load_knownlege(self):
         docments = []
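The change above reworks init_vector_store into a load-or-build pattern: reuse the persisted Chroma index when it exists, otherwise embed the corpus once and persist it, and in both cases hand back the store itself rather than a path. A standalone sketch of that pattern, with a hypothetical embedding model name and caller-supplied paths:

    import os
    from langchain.vectorstores import Chroma
    from langchain.embeddings import HuggingFaceEmbeddings

    def load_or_build_store(persist_dir, documents, model_name="all-MiniLM-L6-v2"):
        # model_name and persist_dir are illustrative, not DB-GPT's defaults.
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
        if os.path.exists(persist_dir):
            # Index already on disk: load it and skip re-embedding the corpus.
            return Chroma(persist_directory=persist_dir, embedding_function=embeddings)
        # First run: embed every document, then write the index to disk.
        store = Chroma.from_documents(documents=documents,
                                      embedding=embeddings,
                                      persist_directory=persist_dir)
        store.persist()
        return store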
@@ -70,10 +69,23 @@ class KnownLedge2Vector:
         return docs
 
     def _load_from_url(self, url):
+        """Load data from a url address"""
        pass
 
+    def query(self, q):
+        """Query similar docs from the vector store"""
+        vector_store = self.init_vector_store()
+        docs = vector_store.similarity_search_with_score(q, k=self.top_k)
+        for doc in docs:
+            dc, s = doc
+            yield s, dc
+
 
 if __name__ == "__main__":
     k2v = KnownLedge2Vector()
-    k2v.init_vector_store()
+
+    persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
+    print(persist_dir)
+    for s, dc in k2v.query("What is OceanBase"):
+        print(s, dc.page_content, dc.metadata)
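similarity_search_with_score returns (Document, score) tuples, which is why query unpacks dc, s and yields the score first; with Chroma's default distance metric, smaller scores mean closer matches. A hedged example of filtering the generator's output (the 0.6 cutoff is illustrative, not from the commit):

    k2v = KnownLedge2Vector()

    MAX_DISTANCE = 0.6  # illustrative threshold, not part of the original code
    for score, doc in k2v.query("What is OceanBase"):
        # Keep only the nearest hits; Chroma distances grow as similarity drops.
        if score <= MAX_DISTANCE:
            print(score, doc.metadata.get("source"))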