feat: add knowledge reference

aries_ckt 2023-11-01 21:55:24 +08:00
parent be1e1cb160
commit 606d384a55
10 changed files with 122 additions and 66 deletions

View File

@@ -187,6 +187,7 @@ class RAGGraphEngine:
        triple_results = []
        for doc in docs:
            import threading

            thread_id = threading.get_ident()
            print(f"current thread-{thread_id} begin extract triplets task")
            triplets = self._extract_triplets(doc.page_content)

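Note: `threading.get_ident()` simply returns the ID of the worker thread that picked up the extraction task, which is what the log line above prints. A minimal standalone sketch of that logging pattern (the executor setup and `extract_triplets` body here are illustrative stand-ins, not DB-GPT code):

```python
import threading
from concurrent.futures import ThreadPoolExecutor

def extract_triplets(text: str):
    # Stand-in for RAGGraphEngine._extract_triplets, which calls an LLM.
    thread_id = threading.get_ident()
    print(f"current thread-{thread_id} begin extract triplets task")
    return [(text, "is", "a chunk")]

with ThreadPoolExecutor(max_workers=4) as pool:
    triple_results = list(pool.map(extract_triplets, ["doc one", "doc two"]))
print(triple_results)
```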
View File

@@ -143,9 +143,7 @@ class RAGGraphSearch(BaseSearch):
        logger.info("> No relationships found, returning nodes found by keywords.")
        if len(sorted_nodes_with_scores) == 0:
            logger.info("> No nodes found by keywords, returning empty response.")
-            return [
-                Document(page_content="No relationships found.")
-            ]
+            return [Document(page_content="No relationships found.")]
        # add relationships as Node
        # TODO: make initial text customizable

View File

@@ -141,7 +141,6 @@ class BaseChat(ABC):
        self.current_message.start_date = datetime.datetime.now().strftime(
            "%Y-%m-%d %H:%M:%S"
        )
        self.current_message.tokens = 0
        if self.prompt_template.template:
            current_prompt = self.prompt_template.format(**input_values)
@@ -152,7 +151,6 @@ class BaseChat(ABC):
        # Not new server mode, we convert the message format(List[ModelMessage]) to list of dict
        # fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post
        llm_messages = list(map(lambda m: m.dict(), llm_messages))
        payload = {
            "model": self.llm_model,
            "prompt": self.generate_llm_text(),
@@ -167,6 +165,9 @@ class BaseChat(ABC):
    def stream_plugin_call(self, text):
        return text

+    def knowledge_reference_call(self, text):
+        return text
+
    async def check_iterator_end(iterator):
        try:
            await asyncio.anext(iterator)
@@ -196,6 +197,7 @@ class BaseChat(ABC):
                view_msg = view_msg.replace("\n", "\\n")
                yield view_msg
            self.current_message.add_ai_message(msg)
+            view_msg = self.knowledge_reference_call(msg)
            self.current_message.add_view_message(view_msg)
        except Exception as e:
            print(traceback.format_exc())

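These two hunks together add a template-method hook: `BaseChat.knowledge_reference_call` returns the streamed text unchanged by default, and the final view message is now built by passing the model output through that hook, so subclasses can append extra content. A self-contained sketch of the pattern with heavily reduced stand-in classes (the real DB-GPT classes carry far more state):

```python
class BaseChat:
    def knowledge_reference_call(self, text: str) -> str:
        # Default hook: leave the model output untouched.
        return text

    def finish_message(self, msg: str) -> str:
        # Mirrors the new line in stream_call: decorate the view message.
        return self.knowledge_reference_call(msg)

class ChatKnowledge(BaseChat):
    def __init__(self, sources):
        self.sources = sources

    def knowledge_reference_call(self, text: str) -> str:
        # Append a references block, as the real override does.
        refs = "\n\n".join(self.sources)
        return f"{text}\n\n##### **References:**\n{refs}"

chat = ChatKnowledge(sources=["report.pdf, **pages**:1,3"])
print(chat.finish_message("The answer is 42."))
```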
View File

@@ -21,13 +21,13 @@ class ExtractRefineSummary(BaseChat):
            chat_param=chat_param,
        )

-        self.user_input = chat_param["current_user_input"]
+        # self.user_input = chat_param["current_user_input"]
        self.existing_answer = chat_param["select_param"]
        # self.extract_mode = chat_param["select_param"]

    def generate_input_values(self):
        input_values = {
-            "context": self.user_input,
+            # "context": self.user_input,
            "existing_answer": self.existing_answer,
        }
        return input_values

View File

@@ -3,18 +3,20 @@ from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
-from pilot.scene.chat_knowledge.refine_summary.out_parser import ExtractRefineSummaryParser
+from pilot.scene.chat_knowledge.refine_summary.out_parser import (
+    ExtractRefineSummaryParser,
+)

CFG = Config()

PROMPT_SCENE_DEFINE = """"""

-_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 我们有机会在下面提供的更多上下文信息的基础上进一步完善现有的总结(仅在需要的情况下)。请根据新的上下文信息,完善原来的总结。\n------------\n{context}\n------------\n如果上下文信息没有用处,请返回原来的总结。"""
+_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 请再完善一下原来的总结。\n回答的时候最好按照1.2.3.点进行总结"""

_DEFAULT_TEMPLATE_EN = """
-We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.\n------------\n{context}\n------------\nGiven the new context, refine the original summary. \nIf the context isn't useful, return the original summary.
-please use original language.
+We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.please refine the original summary.
+\nWhen answering, it is best to summarize according to points 1.2.3.
"""

_DEFAULT_TEMPLATE = (
@@ -29,7 +31,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False
prompt = PromptTemplate(
    template_scene=ChatScene.ExtractRefineSummary.value(),
-    input_variables=["existing_answer","context"],
+    input_variables=["existing_answer"],
    response_format="",
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,

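With `{context}` removed from the refine template, the prompt is now rendered from `existing_answer` alone, which is why `input_variables` shrinks to a single entry. A quick illustration with plain `str.format` (template text abbreviated, sample answer invented):

```python
template = (
    "We have provided an existing summary up to a certain point: {existing_answer}\n"
    "Please refine the original summary, summarizing as points 1.2.3. where possible."
)
print(template.format(existing_answer="DB-GPT supports knowledge-base QA over user documents."))
```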
View File

@@ -11,14 +11,15 @@ CFG = Config()
PROMPT_SCENE_DEFINE = """"""

-_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行简洁地总结:
+_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行总结:
{context}
+回答的时候最好按照1.2.3.点进行总结
"""

_DEFAULT_TEMPLATE_EN = """
-Write a concise summary of the following context:
+Write a summary of the following context:
{context}
-please use original language.
+When answering, it is best to summarize according to points 1.2.3.
"""

_DEFAULT_TEMPLATE = (

View File

@@ -1,5 +1,5 @@
import os
-from typing import Dict
+from typing import Dict, List

from pilot.component import ComponentType
from pilot.scene.base_chat import BaseChat
@@ -20,7 +20,6 @@ CFG = Config()
class ChatKnowledge(BaseChat):
    chat_scene: str = ChatScene.ChatKnowledge.value()
    """KBQA Chat Module"""

    def __init__(self, chat_param: Dict):
@@ -46,7 +45,6 @@ class ChatKnowledge(BaseChat):
            if self.space_context is None
            else int(self.space_context["embedding"]["topk"])
        )
-        # self.recall_score = CFG.KNOWLEDGE_SEARCH_TOP_SIZE if self.space_context is None else self.space_context["embedding"]["recall_score"]
        self.max_token = (
            CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
            if self.space_context is None
@@ -56,11 +54,11 @@ class ChatKnowledge(BaseChat):
            "vector_store_name": self.knowledge_space,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
        }
-        from pilot.graph_engine.graph_factory import RAGGraphFactory
-
-        self.rag_engine = CFG.SYSTEM_APP.get_component(
-            ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
-        ).create()
+        # from pilot.graph_engine.graph_factory import RAGGraphFactory
+        #
+        # self.rag_engine = CFG.SYSTEM_APP.get_component(
+        #     ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+        # ).create()
        embedding_factory = CFG.SYSTEM_APP.get_component(
            "embedding_factory", EmbeddingFactory
        )
@@ -90,25 +88,29 @@ class ChatKnowledge(BaseChat):
                last_output.text = (
                    last_output.text + "\n\nrelations:\n\n" + ",".join(relations)
                )
+            reference = f"\n\n{self.parse_source_view(self.sources)}"
+            last_output = last_output + reference
            yield last_output

+    def knowledge_reference_call(self, text):
+        """return reference"""
+        return text + f"\n\n{self.parse_source_view(self.sources)}"
+
    async def generate_input_values(self) -> Dict:
        if self.space_context:
            self.prompt_template.template_define = self.space_context["prompt"]["scene"]
            self.prompt_template.template = self.space_context["prompt"]["template"]
+        # docs = self.knowledge_embedding_client.similar_search(
+        #     self.current_user_input, self.top_k
+        # )
        docs = await blocking_func_to_async(
            self._executor,
            self.knowledge_embedding_client.similar_search,
            self.current_user_input,
            self.top_k,
        )
-        docs = await self.rag_engine.search(query=self.current_user_input)
-        # docs = self.knowledge_embedding_client.similar_search(
-        #     self.current_user_input, self.top_k
-        # )
+        self.sources = self.merge_by_key(
+            list(map(lambda doc: doc.metadata, docs)), "source"
+        )
+        self.current_message.knowledge_source = self.sources
        if not docs:
            raise ValueError(
                "you have no knowledge space, please add your knowledge space"
@@ -125,6 +127,42 @@ class ChatKnowledge(BaseChat):
        }
        return input_values

+    def parse_source_view(self, sources: List):
+        html_title = f"##### **References:**"
+        lines = ""
+        for item in sources:
+            source = item["source"] if "source" in item else ""
+            pages = ",".join(item["pages"]) if "pages" in item else ""
+            lines += f"{source}"
+            if len(pages) > 0:
+                lines += f", **pages**:{pages}\n\n"
+            else:
+                lines += "\n\n"
+        html = f"""{html_title}\n{lines}"""
+        return html
+
+    def merge_by_key(self, data, key):
+        result = {}
+        for item in data:
+            item_key = os.path.basename(item.get(key))
+            if item_key in result:
+                if "pages" in result[item_key] and "page" in item:
+                    result[item_key]["pages"].append(str(item["page"]))
+                elif "page" in item:
+                    result[item_key]["pages"] = [
+                        result[item_key]["pages"],
+                        str(item["page"]),
+                    ]
+            else:
+                if "page" in item:
+                    result[item_key] = {
+                        "source": item_key,
+                        "pages": [str(item["page"])],
+                    }
+                else:
+                    result[item_key] = {"source": item_key}
+        return list(result.values())
+
    @property
    def chat_type(self) -> str:
        return ChatScene.ChatKnowledge.value()

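These two new helpers drive the reference block: `merge_by_key` groups per-chunk metadata by the basename of the source file and collects page numbers, and `parse_source_view` renders the merged list as a markdown `References` section. A simplified standalone sketch of the intended behavior on hypothetical metadata dicts (the grouping below uses `setdefault` instead of the commit's nested branches, but yields the same shape):

```python
import os

def merge_by_key(data, key):
    # Group metadata dicts by basename of `key`, collecting page numbers.
    result = {}
    for item in data:
        item_key = os.path.basename(item.get(key))
        if item_key in result:
            if "page" in item:
                result[item_key].setdefault("pages", []).append(str(item["page"]))
        elif "page" in item:
            result[item_key] = {"source": item_key, "pages": [str(item["page"])]}
        else:
            result[item_key] = {"source": item_key}
    return list(result.values())

docs_metadata = [
    {"source": "/data/report.pdf", "page": 1},
    {"source": "/data/report.pdf", "page": 3},
    {"source": "/data/notes.md"},
]
print(merge_by_key(docs_metadata, "source"))
# [{'source': 'report.pdf', 'pages': ['1', '3']}, {'source': 'notes.md'}]
```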
View File

@@ -59,6 +59,8 @@ class DocumentSyncRequest(BaseModel):
    """doc_ids: doc ids"""
    doc_ids: List

+    model_name: Optional[str] = None
+
    """Preseparator, this separator is used for pre-splitting before the document is actually split by the text splitter.
    Preseparator are not included in the vectorized text.
    """

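The new optional field lets a sync request pin which LLM generates document summaries, with `KnowledgeService` falling back to `CFG.LLM_MODEL` when it is omitted. A minimal pydantic sketch (the real `DocumentSyncRequest` has more fields, and the model name here is illustrative):

```python
from typing import List, Optional
from pydantic import BaseModel

class DocumentSyncRequest(BaseModel):
    doc_ids: List
    model_name: Optional[str] = None  # service falls back to CFG.LLM_MODEL

req = DocumentSyncRequest(doc_ids=[1, 2], model_name="vicuna-13b-v1.5")
print(req.model_name)
```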
View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel
class ChunkQueryResponse(BaseModel):
    """data: data"""
    data: List = None
    """summary: document summary"""
    summary: str = None

View File

@@ -199,6 +199,7 @@ class KnowledgeService:
        # import langchain is very very slow!!!
        doc_ids = sync_request.doc_ids
+        self.model_name = sync_request.model_name or CFG.LLM_MODEL
        for doc_id in doc_ids:
            query = KnowledgeDocumentEntity(
                id=doc_id,
@@ -427,11 +428,16 @@ class KnowledgeService:
        - doc: KnowledgeDocumentEntity
        """
        from llama_index import PromptHelper
-        from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
+        from llama_index.prompts.default_prompt_selectors import (
+            DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
+        )

        texts = [doc.page_content for doc in chunk_docs]
-        prompt_helper = PromptHelper(context_window=2500)
-        texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
+        prompt_helper = PromptHelper(context_window=2000)
+        texts = prompt_helper.repack(
+            prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts
+        )
        logger.info(
            f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
        )
@@ -445,13 +451,10 @@ class KnowledgeService:
        # print(
        #     f"refine summary outputs:{summaries}"
        # )
-        print(
-            f"final summary:{summary}"
-        )
+        print(f"final summary:{summary}")
        doc.summary = summary
        return knowledge_document_dao.update_knowledge_document(doc)

    def async_doc_embedding(self, client, chunk_docs, doc):
        """async document embedding into vector db
        Args:
@@ -460,11 +463,11 @@ class KnowledgeService:
        - doc: KnowledgeDocumentEntity
        """
        logger.info(
-            f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
+            f"async doc sync, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
        )
        try:
-            vector_ids = client.knowledge_embedding_batch(chunk_docs)
            self.async_document_summary(chunk_docs, doc)
+            vector_ids = client.knowledge_embedding_batch(chunk_docs)
            doc.status = SyncStatus.FINISHED.name
            doc.result = "document embedding success"
            if vector_ids is not None:
@@ -526,25 +529,27 @@ class KnowledgeService:
        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": doc,
-            "select_param": "summary",
-            "model_name": CFG.LLM_MODEL,
+            "select_param": doc,
+            "model_name": self.model_name,
        }
-        from pilot.utils import utils
-        loop = utils.get_or_create_event_loop()
-        summary = loop.run_until_complete(
-            llm_chat_response_nostream(
-                ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
-            )
-        )
-        return summary
+        from pilot.common.chat_util import run_async_tasks
+
+        summary_iters = run_async_tasks(
+            [
+                llm_chat_response_nostream(
+                    ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
+                )
+            ]
+        )
+        return summary_iters[0]

    def _refine_extract_summary(self, docs, summary: str, max_iteration: int = 5):
        """Extract refine summary by llm"""
        from pilot.scene.base import ChatScene
        from pilot.common.chat_util import llm_chat_response_nostream
        import uuid

-        print(
-            f"initialize summary is :{summary}"
-        )
+        print(f"initialize summary is :{summary}")
        outputs = [summary]
        max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
        for doc in docs[0:max_iteration]:
@@ -552,9 +557,10 @@ class KnowledgeService:
                "chat_session_id": uuid.uuid1(),
                "current_user_input": doc,
                "select_param": summary,
-                "model_name": CFG.LLM_MODEL,
+                "model_name": self.model_name,
            }
            from pilot.utils import utils

            loop = utils.get_or_create_event_loop()
            summary = loop.run_until_complete(
                llm_chat_response_nostream(
@@ -562,9 +568,7 @@ class KnowledgeService:
                )
            )
            outputs.append(summary)
-            print(
-                f"iterator is {len(outputs)} current summary is :{summary}"
-            )
+            print(f"iterator is {len(outputs)} current summary is :{summary}")
        return outputs, summary

    def _mapreduce_extract_summary(self, docs):
@@ -577,6 +581,7 @@ class KnowledgeService:
        from pilot.scene.base import ChatScene
        from pilot.common.chat_util import llm_chat_response_nostream
        import uuid

        tasks = []
        max_iteration = 5
        if len(docs) == 1:
@@ -589,17 +594,23 @@ class KnowledgeService:
                "chat_session_id": uuid.uuid1(),
                "current_user_input": doc,
                "select_param": "summary",
-                "model_name": CFG.LLM_MODEL,
+                "model_name": self.model_name,
            }
-            tasks.append(llm_chat_response_nostream(
-                ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
-            ))
+            tasks.append(
+                llm_chat_response_nostream(
+                    ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
+                )
+            )
        from pilot.common.chat_util import run_async_tasks

        summary_iters = run_async_tasks(tasks)
        from pilot.common.prompt_util import PromptHelper
-        from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
+        from llama_index.prompts.default_prompt_selectors import (
+            DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
+        )

        prompt_helper = PromptHelper(context_window=2500)
-        summary_iters = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters)
+        summary_iters = prompt_helper.repack(
+            prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters
+        )
        return self._mapreduce_extract_summary(summary_iters)
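
Overall, `_mapreduce_extract_summary` is a recursive map-reduce: summarize every chunk concurrently, repack the partial summaries into context-window-sized batches, and recurse until a single summary remains. A minimal asyncio sketch of that control flow, under stated simplifications: `summarize` stands in for `llm_chat_response_nostream`, and `repack` is a naive length-based batcher standing in for llama_index's `PromptHelper.repack`:

```python
import asyncio

async def summarize(text: str) -> str:
    # Stand-in for an LLM summarization call.
    await asyncio.sleep(0)
    return text[:80]

def repack(chunks, max_chars=200):
    # Merge small partial summaries so each batch fits a context window.
    batches, current = [], ""
    for chunk in chunks:
        if current and len(current) + len(chunk) > max_chars:
            batches.append(current)
            current = ""
        current += chunk + "\n"
    if current:
        batches.append(current)
    return batches

async def mapreduce_summary(docs):
    if len(docs) == 1:
        return await summarize(docs[0])
    # Map step: summarize all chunks concurrently, like run_async_tasks(tasks).
    partials = await asyncio.gather(*(summarize(d) for d in docs))
    # Reduce step: repack partial summaries and recurse.
    return await mapreduce_summary(repack(partials))

docs = ["chunk one " * 30, "chunk two " * 30, "chunk three " * 30]
print(asyncio.run(mapreduce_summary(docs)))
```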