Mirror of https://github.com/csunny/DB-GPT.git
update: default knowledge init
This commit is contained in:
parent d43a849dca
commit d3567fb984
pilot/server/vectordb_qa.py

@@ -5,6 +5,7 @@ from langchain.prompts import PromptTemplate
 from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
 from pilot.conversation import conv_qa_prompt_template
+from pilot.logs import logger
 from pilot.model.vicuna_llm import VicunaLLM
 from pilot.vector_store.file_loader import KnownLedge2Vector
 
@@ -28,3 +29,27 @@ class KnownLedgeBaseQA:
         context = [d.page_content for d in docs]
         result = prompt.format(context="\n".join(context), question=query)
         return result
+
+    @staticmethod
+    def build_knowledge_prompt(query, docs, state):
+        prompt_template = PromptTemplate(
+            template=conv_qa_prompt_template, input_variables=["context", "question"]
+        )
+        context = [d.page_content for d in docs]
+        result = prompt_template.format(context="\n".join(context), question=query)
+        state.messages[-2][1] = result
+        prompt = state.get_prompt()
+
+        if len(prompt) > 4000:
+            logger.info("prompt length greater than 4000, rebuild")
+            context = context[:2000]
+            prompt_template = PromptTemplate(
+                template=conv_qa_prompt_template,
+                input_variables=["context", "question"],
+            )
+            result = prompt_template.format(context="\n".join(context), question=query)
+            state.messages[-2][1] = result
+            prompt = state.get_prompt()
+            print("new prompt length:" + str(len(prompt)))
+
+        return prompt
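Note on the new helper: build_knowledge_prompt stuffs the retrieved context into the second-to-last message, rebuilds the prompt, and retries once if the result exceeds 4000 characters. Since context here is a list of per-document strings, context[:2000] caps the number of documents rather than characters, and with top-k retrieval it is usually a no-op. A character-based cap would look roughly like the sketch below (MAX_CONTEXT_CHARS and truncate_context are hypothetical names, not part of this commit):

# Hypothetical character-based truncation; names are assumptions,
# not part of this commit.
MAX_CONTEXT_CHARS = 2000

def truncate_context(context, limit=MAX_CONTEXT_CHARS):
    # Join per-document texts first, then hard-cap the joined length.
    joined = "\n".join(context)
    return joined[:limit]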
pilot/server/webserver.py

@@ -13,7 +13,8 @@ from urllib.parse import urljoin
 
 import gradio as gr
 import requests
 from langchain import PromptTemplate
+from pilot.server.vectordb_qa import KnownLedgeBaseQA
 
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
@@ -40,16 +41,13 @@ from pilot.conversation import (
 )
 from pilot.plugins import scan_plugins
 from pilot.prompts.auto_mode_prompt import AutoModePrompt
 from pilot.prompts.generator import PromptGenerator
 from pilot.server.gradio_css import code_highlight_css
 from pilot.server.gradio_patch import Chatbot as grChatbot
-from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 from pilot.utils import build_logger, server_error_msg
 from pilot.vector_store.extract_tovec import (
     get_vector_storelist,
-    knownledge_tovec_st,
-    load_knownledge_from_doc,
 )
 
 logger = build_logger("webserver", LOGDIR + "webserver.log")
@@ -263,10 +261,19 @@ def http_bot(
     prompt = state.get_prompt()
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1
     if mode == conversation_types["default_knownledge"] and not db_selector:
+        vector_store_config = {
+            "vector_store_name": "default",
+            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+        }
+        knowledge_embedding_client = KnowledgeEmbedding(
+            file_path="",
+            model_name=LLM_MODEL_CONFIG["text2vec"],
+            local_persist=False,
+            vector_store_config=vector_store_config,
+        )
         query = state.messages[-2][1]
-        knqa = KnownLedgeBaseQA()
-        state.messages[-2][1] = knqa.get_similar_answer(query)
-        prompt = state.get_prompt()
+        docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
+        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
         state.messages[-2][1] = query
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
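The default-knowledge branch now retrieves from the shared upload directory through KnowledgeEmbedding and delegates prompt assembly to the new helper. Pulled out of the diff as a standalone sketch (constructor arguments as shown in the commit; the config constants come from pilot's config modules):

# Sketch of the retrieval flow above, mirroring the diff.
vector_store_config = {
    "vector_store_name": "default",                # shared default store
    "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
client = KnowledgeEmbedding(
    file_path="",                                  # nothing new to embed, query only
    model_name=LLM_MODEL_CONFIG["text2vec"],       # embedding model key
    local_persist=False,
    vector_store_config=vector_store_config,
)
docs = client.similar_search(query, VECTOR_SEARCH_TOP_K)  # top-k chunks
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)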
@@ -285,26 +292,7 @@ def http_bot(
         )
         query = state.messages[-2][1]
         docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
-        context = [d.page_content for d in docs]
-        prompt_template = PromptTemplate(
-            template=conv_qa_prompt_template, input_variables=["context", "question"]
-        )
-        result = prompt_template.format(context="\n".join(context), question=query)
-        state.messages[-2][1] = result
-        prompt = state.get_prompt()
-        print("prompt length:" + str(len(prompt)))
-
-        if len(prompt) > 4000:
-            logger.info("prompt length greater than 4000, rebuild")
-            context = context[:2000]
-            prompt_template = PromptTemplate(
-                template=conv_qa_prompt_template,
-                input_variables=["context", "question"],
-            )
-            result = prompt_template.format(context="\n".join(context), question=query)
-            state.messages[-2][1] = result
-            prompt = state.get_prompt()
-            print("new prompt length:" + str(len(prompt)))
+        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
 
         state.messages[-2][1] = query
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
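After either knowledge branch, state.messages[-2][1] is reset to the raw query so the visible history shows the question rather than the stuffed prompt, and skip_echo_len is recomputed. An illustration of what that length is used for, assuming the worker streams the prompt back ahead of the completion (the exact streaming protocol is an assumption, not taken from this commit):

# Illustration only; the streaming format is assumed.
prompt = "USER: What is DB-GPT?</s>ASSISTANT:"
skip_echo_len = len(prompt.replace("</s>", " ")) + 1   # +1 eats the separator space
streamed = prompt + " An experimental database-centric GPT project."
visible = streamed.replace("</s>", " ")
completion = visible[skip_echo_len:]                   # echoed prompt sliced off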
@@ -697,7 +685,7 @@ if __name__ == "__main__":
     # initialize configuration
     cfg = Config()
 
-    dbs = get_database_list()
+    # dbs = get_database_list()
     cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
 
     # load plugin executable commands
@@ -2,7 +2,7 @@ import os
 
 import markdown
 from bs4 import BeautifulSoup
-from langchain.document_loaders import PyPDFLoader, TextLoader, markdown
+from langchain.document_loaders import PyPDFLoader, TextLoader
 from langchain.embeddings import HuggingFaceEmbeddings
 
 from pilot.configs.config import Config
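The last hunk fixes a name-shadowing bug: importing markdown out of langchain.document_loaders rebound the top-level import markdown, so later calls such as markdown.markdown(...) would hit the loader module instead of the Markdown package. A minimal reproduction of the hazard (assuming the submodule the removed line imports actually exists in the pinned langchain version):

import markdown                                    # Markdown-to-HTML package
from langchain.document_loaders import markdown    # rebinds the name to a loader module
html = markdown.markdown("# Title")                # now likely an AttributeError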