update:default knowledge init

This commit is contained in:
aries-ckt 2023-05-25 13:01:30 +08:00
parent d43a849dca
commit d3567fb984
3 changed files with 42 additions and 29 deletions

View File

@ -5,6 +5,7 @@ from langchain.prompts import PromptTemplate
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
from pilot.conversation import conv_qa_prompt_template
from pilot.logs import logger
from pilot.model.vicuna_llm import VicunaLLM
from pilot.vector_store.file_loader import KnownLedge2Vector
@ -28,3 +29,27 @@ class KnownLedgeBaseQA:
context = [d.page_content for d in docs]
result = prompt.format(context="\n".join(context), question=query)
return result
@staticmethod
def build_knowledge_prompt(query, docs, state):
prompt_template = PromptTemplate(
template=conv_qa_prompt_template, input_variables=["context", "question"]
)
context = [d.page_content for d in docs]
result = prompt_template.format(context="\n".join(context), question=query)
state.messages[-2][1] = result
prompt = state.get_prompt()
if len(prompt) > 4000:
logger.info("prompt length greater than 4000, rebuild")
context = context[:2000]
prompt_template = PromptTemplate(
template=conv_qa_prompt_template,
input_variables=["context", "question"],
)
result = prompt_template.format(context="\n".join(context), question=query)
state.messages[-2][1] = result
prompt = state.get_prompt()
print("new prompt length:" + str(len(prompt)))
return prompt

View File

@ -13,7 +13,8 @@ from urllib.parse import urljoin
import gradio as gr
import requests
from langchain import PromptTemplate
from pilot.server.vectordb_qa import KnownLedgeBaseQA
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@ -40,16 +41,13 @@ from pilot.conversation import (
)
from pilot.plugins import scan_plugins
from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.generator import PromptGenerator
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.utils import build_logger, server_error_msg
from pilot.vector_store.extract_tovec import (
get_vector_storelist,
knownledge_tovec_st,
load_knownledge_from_doc,
)
logger = build_logger("webserver", LOGDIR + "webserver.log")
@ -263,10 +261,19 @@ def http_bot(
prompt = state.get_prompt()
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["default_knownledge"] and not db_selector:
vector_store_config = {
"vector_store_name": "default",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
query = state.messages[-2][1]
knqa = KnownLedgeBaseQA()
state.messages[-2][1] = knqa.get_similar_answer(query)
prompt = state.get_prompt()
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
@ -285,26 +292,7 @@ def http_bot(
)
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
context = [d.page_content for d in docs]
prompt_template = PromptTemplate(
template=conv_qa_prompt_template, input_variables=["context", "question"]
)
result = prompt_template.format(context="\n".join(context), question=query)
state.messages[-2][1] = result
prompt = state.get_prompt()
print("prompt length:" + str(len(prompt)))
if len(prompt) > 4000:
logger.info("prompt length greater than 4000, rebuild")
context = context[:2000]
prompt_template = PromptTemplate(
template=conv_qa_prompt_template,
input_variables=["context", "question"],
)
result = prompt_template.format(context="\n".join(context), question=query)
state.messages[-2][1] = result
prompt = state.get_prompt()
print("new prompt length:" + str(len(prompt)))
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
@ -697,7 +685,7 @@ if __name__ == "__main__":
# 配置初始化
cfg = Config()
dbs = get_database_list()
# dbs = get_database_list()
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
# 加载插件可执行命令

View File

@ -2,7 +2,7 @@ import os
import markdown
from bs4 import BeautifulSoup
from langchain.document_loaders import PyPDFLoader, TextLoader, markdown
from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config