# Mirror of https://github.com/csunny/DB-GPT.git
# Synced 2025-09-28 21:12:13 +00:00; 72 lines, 2.5 KiB, Python.
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
from langchain.prompts import PromptTemplate
|
|
|
|
from pilot.configs.config import Config
|
|
from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates
|
|
from pilot.logs import logger
|
|
from pilot.model.llm_out.vicuna_llm import VicunaLLM
|
|
from pilot.vector_store.file_loader import KnownLedge2Vector
|
|
|
|
# Module-level configuration object; KnownLedgeBaseQA reads its retrieval
# settings (e.g. KNOWLEDGE_SEARCH_TOP_SIZE) from here.
CFG: Config = Config()
|
|
|
|
|
|
class KnownLedgeBaseQA:
    """Build question-answering prompts backed by a knowledge vector store.

    Construction initialises the vector store through ``KnownLedge2Vector``
    and creates a ``VicunaLLM``. The methods below only *build prompts*;
    none of them invoke the LLM.
    """

    def __init__(self) -> None:
        """Initialise the vector store and the LLM wrapper.

        ``self.vector_store`` comes from
        ``KnownLedge2Vector().init_vector_store()``. ``self.llm`` is a
        ``VicunaLLM``; it is not used by the methods of this class and is
        presumably consumed by callers.
        """
        k2v = KnownLedge2Vector()
        self.vector_store = k2v.init_vector_store()
        self.llm = VicunaLLM()

    @staticmethod
    def _format_qa_prompt(context, question):
        """Fill ``conv_qa_prompt_template`` with *context* and *question*."""
        prompt_template = PromptTemplate(
            template=conv_qa_prompt_template, input_variables=["context", "question"]
        )
        return prompt_template.format(context=context, question=question)

    def get_similar_answer(self, query):
        """Return the QA prompt for *query*, filled with retrieved context.

        Despite its name this does not produce an answer. It retrieves
        documents for *query* from ``self.vector_store`` (with search
        ``k = CFG.KNOWLEDGE_SEARCH_TOP_SIZE``). It then returns
        ``conv_qa_prompt_template`` formatted with the newline-joined
        ``page_content`` of those documents and the query.

        Args:
            query: The user question used both for retrieval and in the prompt.

        Returns:
            The formatted prompt string.
        """
        retriever = self.vector_store.as_retriever(
            search_kwargs={"k": CFG.KNOWLEDGE_SEARCH_TOP_SIZE}
        )
        docs = retriever.get_relevant_documents(query=query)
        context = "\n".join(d.page_content for d in docs)
        return self._format_qa_prompt(context, query)

    @staticmethod
    def build_knowledge_prompt(
        query, docs, state, *, max_prompt_length=4000, context_char_limit=2000
    ):
        """Build the conversation prompt for a knowledge-base question.

        The prompt is built in four steps:

        1. Format ``conv_qa_prompt_template`` with the newline-joined
           ``page_content`` of *docs* and *query*.
        2. Store the result in ``state.messages[-2][1]``.
        3. Call ``state.get_prompt()``.
        4. If that prompt is longer than *max_prompt_length* characters, cut
           the joined context to its first *context_char_limit* characters
           and rebuild once.

        A single rebuild does not guarantee the final prompt fits under
        *max_prompt_length*; the conversation history and query also count.

        Args:
            query: The user question.
            docs: Retrieved documents; each must expose ``page_content``.
            state: Conversation state; mutated in place via
                ``state.messages[-2][1]``.
            max_prompt_length: Prompt length (characters) that triggers the
                truncation/rebuild. Defaults to 4000 (previous hard-coded value).
            context_char_limit: Number of context characters kept on rebuild.
                Defaults to 2000 (previous hard-coded value).

        Returns:
            The prompt string returned by ``state.get_prompt()``.
        """
        context = "\n".join(d.page_content for d in docs)
        # NOTE(review): assumes state.messages[-2] is the mutable [role, text]
        # slot for the pending user turn; confirm against the caller.
        state.messages[-2][1] = KnownLedgeBaseQA._format_qa_prompt(context, query)
        prompt = state.get_prompt()

        if len(prompt) > max_prompt_length:
            # Single pre-formatted string: works whether `logger` is a stdlib
            # logging.Logger or a project wrapper with a different signature.
            logger.info(
                f"prompt length {len(prompt)} greater than {max_prompt_length}, rebuild"
            )
            # Truncate the joined context *text*. The previous code sliced the
            # list of documents (context[:2000]), which left the prompt unchanged.
            context = context[:context_char_limit]
            state.messages[-2][1] = KnownLedgeBaseQA._format_qa_prompt(context, query)
            prompt = state.get_prompt()
            logger.info(f"new prompt length:{len(prompt)}")

        return prompt

    @staticmethod
    def build_db_summary_prompt(query, db_profile_summary, state):
        """Build the conversation prompt for a database question.

        Formats ``conv_db_summary_templates`` with *query* (as ``db_input``)
        and *db_profile_summary*, stores it in ``state.messages[-2][1]`` and
        returns ``state.get_prompt()``.

        Args:
            query: The user question, bound to the template's ``db_input``.
            db_profile_summary: Text describing the database profile.
            state: Conversation state; mutated in place via
                ``state.messages[-2][1]``.

        Returns:
            The prompt string returned by ``state.get_prompt()``.
        """
        prompt_template = PromptTemplate(
            template=conv_db_summary_templates,
            input_variables=["db_input", "db_profile_summary"],
        )
        state.messages[-2][1] = prompt_template.format(
            db_profile_summary=db_profile_summary, db_input=query
        )
        return state.get_prompt()
|