mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-20 09:14:44 +00:00)

commit d3567fb984 (parent d43a849dca)

    update:default knowledge init
pilot/server/vectordb_qa.py

@@ -5,6 +5,7 @@ from langchain.prompts import PromptTemplate
 from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
 from pilot.conversation import conv_qa_prompt_template
+from pilot.logs import logger
 from pilot.model.vicuna_llm import VicunaLLM
 from pilot.vector_store.file_loader import KnownLedge2Vector
 
 
@@ -28,3 +29,27 @@ class KnownLedgeBaseQA:
         context = [d.page_content for d in docs]
         result = prompt.format(context="\n".join(context), question=query)
         return result
+
+    @staticmethod
+    def build_knowledge_prompt(query, docs, state):
+        prompt_template = PromptTemplate(
+            template=conv_qa_prompt_template, input_variables=["context", "question"]
+        )
+        context = [d.page_content for d in docs]
+        result = prompt_template.format(context="\n".join(context), question=query)
+        state.messages[-2][1] = result
+        prompt = state.get_prompt()
+
+        if len(prompt) > 4000:
+            logger.info("prompt length greater than 4000, rebuild")
+            context = context[:2000]
+            prompt_template = PromptTemplate(
+                template=conv_qa_prompt_template,
+                input_variables=["context", "question"],
+            )
+            result = prompt_template.format(context="\n".join(context), question=query)
+            state.messages[-2][1] = result
+            prompt = state.get_prompt()
+            print("new prompt length:" + str(len(prompt)))
+
+        return prompt
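
The new build_knowledge_prompt helper can be exercised on its own. Below is a minimal sketch, not DB-GPT code: Doc and FakeState are hypothetical stand-ins for a langchain Document and the webserver's conversation state, which (per the hunk above) must expose a messages list of [role, text] pairs and a get_prompt() method. Note that in the rebuild branch, context[:2000] truncates the number of retrieved chunks rather than the character length of the joined context.

    from pilot.server.vectordb_qa import KnownLedgeBaseQA

    class Doc:  # hypothetical stand-in for a langchain Document
        def __init__(self, page_content):
            self.page_content = page_content

    class FakeState:  # hypothetical stand-in for the conversation state
        def __init__(self):
            # messages[-2] holds the latest user turn as [role, text]
            self.messages = [["USER", "What is DB-GPT?"], ["ASSISTANT", None]]

        def get_prompt(self):
            return "\n".join(text or "" for _, text in self.messages)

    docs = [Doc("DB-GPT is a database-centric LLM framework.")]
    prompt = KnownLedgeBaseQA.build_knowledge_prompt("What is DB-GPT?", docs, FakeState())
    print("prompt length:", len(prompt))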
pilot/server/webserver.py

@@ -13,7 +13,8 @@ from urllib.parse import urljoin
 import gradio as gr
 import requests
-from langchain import PromptTemplate
+
+from pilot.server.vectordb_qa import KnownLedgeBaseQA
 
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
 
@@ -40,16 +41,13 @@ from pilot.conversation import (
 )
 from pilot.plugins import scan_plugins
 from pilot.prompts.auto_mode_prompt import AutoModePrompt
-from pilot.prompts.generator import PromptGenerator
 from pilot.server.gradio_css import code_highlight_css
 from pilot.server.gradio_patch import Chatbot as grChatbot
-from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 from pilot.utils import build_logger, server_error_msg
 from pilot.vector_store.extract_tovec import (
     get_vector_storelist,
     knownledge_tovec_st,
-    load_knownledge_from_doc,
 )
 
 logger = build_logger("webserver", LOGDIR + "webserver.log")
@@ -263,10 +261,19 @@ def http_bot(
     prompt = state.get_prompt()
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1
     if mode == conversation_types["default_knownledge"] and not db_selector:
+        vector_store_config = {
+            "vector_store_name": "default",
+            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+        }
+        knowledge_embedding_client = KnowledgeEmbedding(
+            file_path="",
+            model_name=LLM_MODEL_CONFIG["text2vec"],
+            local_persist=False,
+            vector_store_config=vector_store_config,
+        )
         query = state.messages[-2][1]
-        knqa = KnownLedgeBaseQA()
-        state.messages[-2][1] = knqa.get_similar_answer(query)
-        prompt = state.get_prompt()
+        docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
+        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
         state.messages[-2][1] = query
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
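
Taken together, the branch above reduces to the retrieval flow below, shown as a minimal sketch. It assumes KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, and VECTOR_SEARCH_TOP_K live in pilot.configs.model_config, as the webserver's other imports suggest; the KnowledgeEmbedding arguments are exactly those introduced by the hunk.

    from pilot.configs.model_config import (
        KNOWLEDGE_UPLOAD_ROOT_PATH,
        LLM_MODEL_CONFIG,
        VECTOR_SEARCH_TOP_K,
    )
    from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding

    client = KnowledgeEmbedding(
        file_path="",  # empty path: presumably query-only, no ingestion
        model_name=LLM_MODEL_CONFIG["text2vec"],
        local_persist=False,
        vector_store_config={
            "vector_store_name": "default",
            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
        },
    )
    docs = client.similar_search("How do I register a plugin?", VECTOR_SEARCH_TOP_K)
    for d in docs:
        print(d.page_content[:80])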
@@ -285,26 +292,7 @@ def http_bot(
         )
         query = state.messages[-2][1]
         docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
-        context = [d.page_content for d in docs]
-        prompt_template = PromptTemplate(
-            template=conv_qa_prompt_template, input_variables=["context", "question"]
-        )
-        result = prompt_template.format(context="\n".join(context), question=query)
-        state.messages[-2][1] = result
-        prompt = state.get_prompt()
-        print("prompt length:" + str(len(prompt)))
-
-        if len(prompt) > 4000:
-            logger.info("prompt length greater than 4000, rebuild")
-            context = context[:2000]
-            prompt_template = PromptTemplate(
-                template=conv_qa_prompt_template,
-                input_variables=["context", "question"],
-            )
-            result = prompt_template.format(context="\n".join(context), question=query)
-            state.messages[-2][1] = result
-            prompt = state.get_prompt()
-            print("new prompt length:" + str(len(prompt)))
+        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
 
         state.messages[-2][1] = query
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
@@ -697,7 +685,7 @@ if __name__ == "__main__":
     # initialize configuration
     cfg = Config()
 
-    dbs = get_database_list()
+    # dbs = get_database_list()
     cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
 
     # load plugins' executable commands
@@ -2,7 +2,7 @@ import os
 
 import markdown
 from bs4 import BeautifulSoup
-from langchain.document_loaders import PyPDFLoader, TextLoader, markdown
+from langchain.document_loaders import PyPDFLoader, TextLoader
 from langchain.embeddings import HuggingFaceEmbeddings
 
 from pilot.configs.config import Config
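
This last change likely fixes a name collision: langchain.document_loaders also exposes a markdown submodule, so listing it in that import rebound the name markdown from the Python-Markdown package (imported two lines above) to langchain's loader module, breaking calls such as markdown.markdown(...). A minimal sketch of the failure mode:

    import markdown  # the Python-Markdown package
    # Before the fix, this line shadowed the import above:
    #   from langchain.document_loaders import PyPDFLoader, TextLoader, markdown
    from langchain.document_loaders import PyPDFLoader, TextLoader

    html = markdown.markdown("# heading")
    print(html)  # <h1>heading</h1>, works once the shadowing import is gone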