mirror of https://github.com/csunny/DB-GPT.git

commit 606d384a55 (parent be1e1cb160)
feat:add knowledge reference
@@ -187,6 +187,7 @@ class RAGGraphEngine:
         triple_results = []
         for doc in docs:
             import threading

             thread_id = threading.get_ident()
+            print(f"current thread-{thread_id} begin extract triplets task")
             triplets = self._extract_triplets(doc.page_content)
@@ -143,9 +143,7 @@ class RAGGraphSearch(BaseSearch):
         logger.info("> No relationships found, returning nodes found by keywords.")
         if len(sorted_nodes_with_scores) == 0:
             logger.info("> No nodes found by keywords, returning empty response.")
-            return [
-                Document(page_content="No relationships found.")
-            ]
+            return [Document(page_content="No relationships found.")]

         # add relationships as Node
         # TODO: make initial text customizable
@@ -141,7 +141,6 @@ class BaseChat(ABC):
         self.current_message.start_date = datetime.datetime.now().strftime(
             "%Y-%m-%d %H:%M:%S"
         )
-
         self.current_message.tokens = 0
         if self.prompt_template.template:
             current_prompt = self.prompt_template.format(**input_values)
@@ -152,7 +151,6 @@ class BaseChat(ABC):
             # Not new server mode, we convert the message format(List[ModelMessage]) to list of dict
             # fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post
             llm_messages = list(map(lambda m: m.dict(), llm_messages))
-
         payload = {
             "model": self.llm_model,
             "prompt": self.generate_llm_text(),
@@ -167,6 +165,9 @@ class BaseChat(ABC):
     def stream_plugin_call(self, text):
         return text

+    def knowledge_reference_call(self, text):
+        return text
+
     async def check_iterator_end(iterator):
         try:
             await asyncio.anext(iterator)
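Note: knowledge_reference_call is a no-op hook on the base class; subclasses override it to decorate the final answer. A minimal, self-contained sketch of the pattern (these classes are stand-ins, not the real pilot.scene.base_chat classes):

# Sketch only: stand-ins for BaseChat/ChatKnowledge, not the pilot classes.
class BaseChat:
    def knowledge_reference_call(self, text: str) -> str:
        # Default hook: return the model output unchanged.
        return text

    def finish(self, msg: str) -> str:
        # Mirrors the streaming loop in this commit: decorate the final
        # message before it is stored as the "view" message.
        return self.knowledge_reference_call(msg)


class ChatKnowledge(BaseChat):
    def __init__(self, sources):
        self.sources = sources

    def knowledge_reference_call(self, text: str) -> str:
        # Append a references section, as the ChatKnowledge override does.
        refs = "\n\n".join(item["source"] for item in self.sources)
        return f"{text}\n\n##### **References:**\n{refs}"


print(ChatKnowledge([{"source": "manual.pdf"}]).finish("The answer is 42."))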
@@ -196,6 +197,7 @@ class BaseChat(ABC):
                 view_msg = view_msg.replace("\n", "\\n")
                 yield view_msg
             self.current_message.add_ai_message(msg)
+            view_msg = self.knowledge_reference_call(msg)
             self.current_message.add_view_message(view_msg)
         except Exception as e:
             print(traceback.format_exc())
@@ -21,13 +21,13 @@ class ExtractRefineSummary(BaseChat):
             chat_param=chat_param,
         )

-        self.user_input = chat_param["current_user_input"]
+        # self.user_input = chat_param["current_user_input"]
         self.existing_answer = chat_param["select_param"]
         # self.extract_mode = chat_param["select_param"]

     def generate_input_values(self):
         input_values = {
-            "context": self.user_input,
+            # "context": self.user_input,
             "existing_answer": self.existing_answer,
         }
         return input_values
@@ -3,18 +3,20 @@ from pilot.configs.config import Config
 from pilot.scene.base import ChatScene
 from pilot.common.schema import SeparatorStyle

-from pilot.scene.chat_knowledge.refine_summary.out_parser import ExtractRefineSummaryParser
+from pilot.scene.chat_knowledge.refine_summary.out_parser import (
+    ExtractRefineSummaryParser,
+)

 CFG = Config()


 PROMPT_SCENE_DEFINE = """"""

-_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 我们有机会在下面提供的更多上下文信息的基础上进一步完善现有的总结(仅在需要的情况下)。请根据新的上下文信息,完善原来的总结。\n------------\n{context}\n------------\n如果上下文信息没有用处,请返回原来的总结。"""
+_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 请再完善一下原来的总结。\n回答的时候最好按照1.2.3.点进行总结"""

 _DEFAULT_TEMPLATE_EN = """
-We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.\n------------\n{context}\n------------\nGiven the new context, refine the original summary. \nIf the context isn't useful, return the original summary.
-please use original language.
+We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.please refine the original summary.
+\nWhen answering, it is best to summarize according to points 1.2.3.
 """

 _DEFAULT_TEMPLATE = (
@@ -29,7 +31,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False

 prompt = PromptTemplate(
     template_scene=ChatScene.ExtractRefineSummary.value(),
-    input_variables=["existing_answer","context"],
+    input_variables=["existing_answer"],
     response_format="",
     template_define=PROMPT_SCENE_DEFINE,
     template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
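With {context} gone from the refine template, input_variables shrinks to ["existing_answer"]. A quick stdlib way to sanity-check that a template's placeholders match its declared variables (string.Formatter is standard Python; pilot's PromptTemplate is not needed for this check):

from string import Formatter

template = (
    "We have provided an existing summary up to a certain point: {existing_answer}\n"
    "We have the opportunity to refine the existing summary (only if needed) "
    "with some more context below.please refine the original summary.\n"
    "When answering, it is best to summarize according to points 1.2.3."
)

# Collect the placeholder names that actually occur in the template.
fields = {name for _, name, _, _ in Formatter().parse(template) if name}
assert fields == {"existing_answer"}, fields  # agrees with input_variables
print(template.format(existing_answer="1. DB-GPT runs local LLMs."))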
@@ -11,14 +11,15 @@ CFG = Config()

 PROMPT_SCENE_DEFINE = """"""

-_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行简洁地总结:
+_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行总结:
 {context}
+回答的时候最好按照1.2.3.点进行总结
 """

 _DEFAULT_TEMPLATE_EN = """
-Write a concise summary of the following context:
+Write a summary of the following context:
 {context}
-please use original language.
+When answering, it is best to summarize according to points 1.2.3.
 """

 _DEFAULT_TEMPLATE = (
@@ -1,5 +1,5 @@
 import os
-from typing import Dict
+from typing import Dict, List

 from pilot.component import ComponentType
 from pilot.scene.base_chat import BaseChat
@@ -20,7 +20,6 @@ CFG = Config()

 class ChatKnowledge(BaseChat):
     chat_scene: str = ChatScene.ChatKnowledge.value()
-
     """KBQA Chat Module"""

     def __init__(self, chat_param: Dict):
@@ -46,7 +45,6 @@ class ChatKnowledge(BaseChat):
             if self.space_context is None
             else int(self.space_context["embedding"]["topk"])
         )
-        # self.recall_score = CFG.KNOWLEDGE_SEARCH_TOP_SIZE if self.space_context is None else self.space_context["embedding"]["recall_score"]
         self.max_token = (
             CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
             if self.space_context is None
@@ -56,11 +54,11 @@ class ChatKnowledge(BaseChat):
             "vector_store_name": self.knowledge_space,
             "vector_store_type": CFG.VECTOR_STORE_TYPE,
         }
-        from pilot.graph_engine.graph_factory import RAGGraphFactory
-
-        self.rag_engine = CFG.SYSTEM_APP.get_component(
-            ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
-        ).create()
+        # from pilot.graph_engine.graph_factory import RAGGraphFactory
+        #
+        # self.rag_engine = CFG.SYSTEM_APP.get_component(
+        #     ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+        # ).create()
         embedding_factory = CFG.SYSTEM_APP.get_component(
             "embedding_factory", EmbeddingFactory
         )
@@ -90,25 +88,29 @@ class ChatKnowledge(BaseChat):
                 last_output.text = (
                     last_output.text + "\n\nrelations:\n\n" + ",".join(relations)
                 )
             yield last_output
+        reference = f"\n\n{self.parse_source_view(self.sources)}"
+        last_output = last_output + reference
+        yield last_output

+    def knowledge_reference_call(self, text):
+        """return reference"""
+        return text + f"\n\n{self.parse_source_view(self.sources)}"
+
     async def generate_input_values(self) -> Dict:
         if self.space_context:
             self.prompt_template.template_define = self.space_context["prompt"]["scene"]
             self.prompt_template.template = self.space_context["prompt"]["template"]
-        # docs = self.knowledge_embedding_client.similar_search(
-        #     self.current_user_input, self.top_k
-        # )
-        docs = await self.rag_engine.search(query=self.current_user_input)
+        docs = await blocking_func_to_async(
+            self._executor,
+            self.knowledge_embedding_client.similar_search,
+            self.current_user_input,
+            self.top_k,
+        )
+        # docs = self.knowledge_embedding_client.similar_search(
+        #     self.current_user_input, self.top_k
+        # )
+        self.sources = self.merge_by_key(
+            list(map(lambda doc: doc.metadata, docs)), "source"
+        )

+        self.current_message.knowledge_source = self.sources
         if not docs:
             raise ValueError(
                 "you have no knowledge space, please add your knowledge space"
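similar_search is a blocking call, so it is dispatched to an executor instead of being awaited directly on the event loop. A self-contained sketch of what a helper like blocking_func_to_async does (the helper and the search function below are stand-ins; the real helper lives in pilot and its signature may differ):

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


async def blocking_func_to_async(executor, func, *args):
    # Run the synchronous function on the executor so the loop stays responsive.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, func, *args)


def similar_search(query, top_k):
    time.sleep(0.1)  # stand-in for a blocking vector-store lookup
    return [f"doc-{i} for {query!r}" for i in range(top_k)]


async def main():
    with ThreadPoolExecutor() as executor:
        docs = await blocking_func_to_async(executor, similar_search, "db-gpt", 3)
        print(docs)


asyncio.run(main())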
@@ -125,6 +127,42 @@ class ChatKnowledge(BaseChat):
         }
         return input_values

+    def parse_source_view(self, sources: List):
+        html_title = f"##### **References:**"
+        lines = ""
+        for item in sources:
+            source = item["source"] if "source" in item else ""
+            pages = ",".join(item["pages"]) if "pages" in item else ""
+            lines += f"{source}"
+            if len(pages) > 0:
+                lines += f", **pages**:{pages}\n\n"
+            else:
+                lines += "\n\n"
+        html = f"""{html_title}\n{lines}"""
+        return html
+
+    def merge_by_key(self, data, key):
+        result = {}
+        for item in data:
+            item_key = os.path.basename(item.get(key))
+            if item_key in result:
+                if "pages" in result[item_key] and "page" in item:
+                    result[item_key]["pages"].append(str(item["page"]))
+                elif "page" in item:
+                    result[item_key]["pages"] = [
+                        result[item_key]["pages"],
+                        str(item["page"]),
+                    ]
+            else:
+                if "page" in item:
+                    result[item_key] = {
+                        "source": item_key,
+                        "pages": [str(item["page"])],
+                    }
+                else:
+                    result[item_key] = {"source": item_key}
+        return list(result.values())
+
     @property
     def chat_type(self) -> str:
         return ChatScene.ChatKnowledge.value()
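For reference: merge_by_key groups chunk metadata by the basename of the source path and collects page numbers as strings, and parse_source_view renders those entries as a markdown reference list. A compact sketch of the same grouping using dict.setdefault, with an assumed metadata shape for the sample input:

import os


def merge_by_key(data, key):
    # Group items by basename(item[key]); collect "page" values as strings.
    result = {}
    for item in data:
        item_key = os.path.basename(item.get(key))
        entry = result.setdefault(item_key, {"source": item_key})
        if "page" in item:
            entry.setdefault("pages", []).append(str(item["page"]))
    return list(result.values())


chunks = [
    {"source": "/data/docs/manual.pdf", "page": 1},
    {"source": "/data/docs/manual.pdf", "page": 3},
    {"source": "/data/docs/notes.md"},
]
print(merge_by_key(chunks, "source"))
# [{'source': 'manual.pdf', 'pages': ['1', '3']}, {'source': 'notes.md'}]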
@@ -59,6 +59,8 @@ class DocumentSyncRequest(BaseModel):
     """doc_ids: doc ids"""
     doc_ids: List

+    model_name: Optional[str] = None
+
     """Preseparator, this separator is used for pre-splitting before the document is actually split by the text splitter.
     Preseparator are not included in the vectorized text.
     """
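The sync request now carries an optional model_name, which KnowledgeService falls back from to CFG.LLM_MODEL (see the hunk below). A minimal sketch with only the fields visible in this hunk (the real model defines more):

from typing import List, Optional

from pydantic import BaseModel


class DocumentSyncRequest(BaseModel):
    # Only the fields shown in the hunk above; the real model has more.
    doc_ids: List
    model_name: Optional[str] = None


req = DocumentSyncRequest(doc_ids=[42])
print(req.model_name)  # None -> the service falls back to CFG.LLM_MODEL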
@@ -5,6 +5,7 @@ from pydantic import BaseModel

 class ChunkQueryResponse(BaseModel):
     """data: data"""

     data: List = None
+    """summary: document summary"""
+    summary: str = None
@@ -199,6 +199,7 @@ class KnowledgeService:
         # import langchain is very very slow!!!

         doc_ids = sync_request.doc_ids
+        self.model_name = sync_request.model_name or CFG.LLM_MODEL
         for doc_id in doc_ids:
             query = KnowledgeDocumentEntity(
                 id=doc_id,
@@ -427,11 +428,16 @@ class KnowledgeService:
         - doc: KnowledgeDocumentEntity
         """
         from llama_index import PromptHelper
-        from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
-        texts = [doc.page_content for doc in chunk_docs]
-        prompt_helper = PromptHelper(context_window=2500)
-        texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
+        from llama_index.prompts.default_prompt_selectors import (
+            DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
+        )
+
+        texts = [doc.page_content for doc in chunk_docs]
+        prompt_helper = PromptHelper(context_window=2000)
+
+        texts = prompt_helper.repack(
+            prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts
+        )
         logger.info(
             f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
         )
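PromptHelper.repack merges the chunk texts and re-splits them so each piece fills the available context window (note the window also shrinks from 2500 to 2000 here). A character-based stand-in for the idea; the real llama_index helper is token-aware and subtracts the prompt's own size:

def repack(text_chunks, chunk_size=2000):
    # Join the chunks, then re-split into pieces of at most chunk_size chars.
    joined = "\n".join(chunk.strip() for chunk in text_chunks if chunk.strip())
    return [joined[i : i + chunk_size] for i in range(0, len(joined), chunk_size)]


texts = ["alpha " * 100, "beta " * 100, "gamma " * 100]
packed = repack(texts, chunk_size=2000)
print(f"{len(texts)} chunks in, {len(packed)} chunk(s) out")  # 3 in, 1 out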
@@ -445,13 +451,10 @@ class KnowledgeService:
         # print(
         #     f"refine summary outputs:{summaries}"
         # )
-        print(
-            f"final summary:{summary}"
-        )
+        print(f"final summary:{summary}")
         doc.summary = summary
         return knowledge_document_dao.update_knowledge_document(doc)

-
     def async_doc_embedding(self, client, chunk_docs, doc):
         """async document embedding into vector db
         Args:
@@ -460,11 +463,11 @@ class KnowledgeService:
         - doc: KnowledgeDocumentEntity
         """
         logger.info(
-            f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
+            f"async doc sync, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
         )
         try:
-            vector_ids = client.knowledge_embedding_batch(chunk_docs)
+            self.async_document_summary(chunk_docs, doc)
+            vector_ids = client.knowledge_embedding_batch(chunk_docs)
             doc.status = SyncStatus.FINISHED.name
             doc.result = "document embedding success"
             if vector_ids is not None:
@@ -526,25 +529,27 @@ class KnowledgeService:
         chat_param = {
             "chat_session_id": uuid.uuid1(),
             "current_user_input": doc,
-            "select_param": "summary",
-            "model_name": CFG.LLM_MODEL,
+            "select_param": doc,
+            "model_name": self.model_name,
         }
-        from pilot.utils import utils
-        loop = utils.get_or_create_event_loop()
-        summary = loop.run_until_complete(
-            llm_chat_response_nostream(
-                ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
-            )
-        )
-        return summary
+        from pilot.common.chat_util import run_async_tasks
+
+        summary_iters = run_async_tasks(
+            [
+                llm_chat_response_nostream(
+                    ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
+                )
+            ]
+        )
+        return summary_iters[0]

-    def _refine_extract_summary(self, docs, summary: str, max_iteration:int = 5):
+    def _refine_extract_summary(self, docs, summary: str, max_iteration: int = 5):
         """Extract refine summary by llm"""
         from pilot.scene.base import ChatScene
         from pilot.common.chat_util import llm_chat_response_nostream
         import uuid
-        print(
-            f"initialize summary is :{summary}"
-        )
+
+        print(f"initialize summary is :{summary}")
         outputs = [summary]
         max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
         for doc in docs[0:max_iteration]:
@@ -552,9 +557,10 @@ class KnowledgeService:
             "chat_session_id": uuid.uuid1(),
             "current_user_input": doc,
             "select_param": summary,
-            "model_name": CFG.LLM_MODEL,
+            "model_name": self.model_name,
         }
         from pilot.utils import utils
+
         loop = utils.get_or_create_event_loop()
         summary = loop.run_until_complete(
             llm_chat_response_nostream(
@@ -562,9 +568,7 @@ class KnowledgeService:
             )
         )
         outputs.append(summary)
-        print(
-            f"iterator is {len(outputs)} current summary is :{summary}"
-        )
+        print(f"iterator is {len(outputs)} current summary is :{summary}")
         return outputs, summary

     def _mapreduce_extract_summary(self, docs):
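For contrast with the map-reduce path below, _refine_extract_summary threads one running summary through the chunks sequentially: each round feeds the previous summary (as select_param / existing_answer) to ExtractRefineSummary, capped at max_iteration rounds. A control-flow sketch with the LLM call stubbed out:

def refine_extract_summary(docs, summary, max_iteration=5):
    def llm_refine(existing_answer, doc):
        # Stand-in for the ExtractRefineSummary LLM round-trip.
        return f"{existing_answer}; plus key points of {doc!r}"

    outputs = [summary]
    max_iteration = min(max_iteration, len(docs))  # same cap as above
    for doc in docs[:max_iteration]:
        summary = llm_refine(summary, doc)
        outputs.append(summary)
    return outputs, summary


outputs, final = refine_extract_summary(["chunk A", "chunk B"], "seed summary")
print(len(outputs), final)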
@@ -577,6 +581,7 @@ class KnowledgeService:
         from pilot.scene.base import ChatScene
         from pilot.common.chat_util import llm_chat_response_nostream
         import uuid
+
         tasks = []
         max_iteration = 5
         if len(docs) == 1:
@@ -589,17 +594,23 @@ class KnowledgeService:
                 "chat_session_id": uuid.uuid1(),
                 "current_user_input": doc,
                 "select_param": "summary",
-                "model_name": CFG.LLM_MODEL,
+                "model_name": self.model_name,
             }
-            tasks.append(llm_chat_response_nostream(
-                ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
-            ))
+            tasks.append(
+                llm_chat_response_nostream(
+                    ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
+                )
+            )
         from pilot.common.chat_util import run_async_tasks

         summary_iters = run_async_tasks(tasks)
         from pilot.common.prompt_util import PromptHelper
-        from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
+        from llama_index.prompts.default_prompt_selectors import (
+            DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
+        )
+
         prompt_helper = PromptHelper(context_window=2500)
-        summary_iters = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters)
+        summary_iters = prompt_helper.repack(
+            prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters
+        )
         return self._mapreduce_extract_summary(summary_iters)
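_mapreduce_extract_summary fans one ExtractSummary task out per chunk (at most five), gathers them with run_async_tasks, repacks the partial summaries to fit the window, and recurses until a single summary remains. A self-contained sketch of that flow; run_async_tasks here is a stand-in for pilot.common.chat_util.run_async_tasks and the LLM call is stubbed:

import asyncio


def run_async_tasks(tasks):
    # Stand-in: drive a batch of coroutines to completion from sync code.
    async def _gather():
        return await asyncio.gather(*tasks)

    return asyncio.run(_gather())


async def extract_summary(chunk):
    await asyncio.sleep(0)  # stand-in for the ExtractSummary LLM call
    return chunk[:30]


def mapreduce_extract_summary(docs):
    if len(docs) == 1:
        return docs[0]
    # Map: summarize up to five chunks concurrently.
    summaries = run_async_tasks([extract_summary(d) for d in docs[:5]])
    # Reduce: merge the partial summaries and recurse toward one summary.
    return mapreduce_extract_summary(["\n".join(summaries)])


print(mapreduce_extract_summary(["first chunk", "second chunk", "third chunk"]))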