feat: add knowledge reference

This commit is contained in:
aries_ckt 2023-11-01 21:55:24 +08:00
parent be1e1cb160
commit 606d384a55
10 changed files with 122 additions and 66 deletions

View File

@ -187,6 +187,7 @@ class RAGGraphEngine:
triple_results = []
for doc in docs:
import threading
thread_id = threading.get_ident()
print(f"current thread-{thread_id} begin extract triplets task")
triplets = self._extract_triplets(doc.page_content)
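The hunk above tags each triplet-extraction call with the worker thread that runs it. A minimal standalone sketch of that pattern, with a hypothetical extract_triplets stub standing in for RAGGraphEngine._extract_triplets:

```python
import threading
from concurrent.futures import ThreadPoolExecutor

def extract_triplets(text: str) -> list:
    # Hypothetical stand-in for RAGGraphEngine._extract_triplets.
    return [("subject", "predicate", text)]

def process(doc_text: str) -> list:
    # Log the worker thread before extracting, as the diff does.
    thread_id = threading.get_ident()
    print(f"current thread-{thread_id} begin extract triplets task")
    return extract_triplets(doc_text)

docs = ["Alice founded Acme.", "Acme acquired Beta."]
with ThreadPoolExecutor(max_workers=2) as pool:
    triple_results = list(pool.map(process, docs))
```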

View File

@ -143,9 +143,7 @@ class RAGGraphSearch(BaseSearch):
logger.info("> No relationships found, returning nodes found by keywords.")
if len(sorted_nodes_with_scores) == 0:
logger.info("> No nodes found by keywords, returning empty response.")
return [
Document(page_content="No relationships found.")
]
return [Document(page_content="No relationships found.")]
# add relationships as Node
# TODO: make initial text customizable

View File

@ -141,7 +141,6 @@ class BaseChat(ABC):
self.current_message.start_date = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
self.current_message.tokens = 0
if self.prompt_template.template:
current_prompt = self.prompt_template.format(**input_values)
@ -152,7 +151,6 @@ class BaseChat(ABC):
# Not new server mode, we convert the message format(List[ModelMessage]) to list of dict
# fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post
llm_messages = list(map(lambda m: m.dict(), llm_messages))
payload = {
"model": self.llm_model,
"prompt": self.generate_llm_text(),
@ -167,6 +165,9 @@ class BaseChat(ABC):
def stream_plugin_call(self, text):
return text
def knowledge_reference_call(self, text):
return text
async def check_iterator_end(iterator):
try:
await asyncio.anext(iterator)
@ -196,6 +197,7 @@ class BaseChat(ABC):
view_msg = view_msg.replace("\n", "\\n")
yield view_msg
self.current_message.add_ai_message(msg)
view_msg = self.knowledge_reference_call(msg)
self.current_message.add_view_message(view_msg)
except Exception as e:
print(traceback.format_exc())
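The `knowledge_reference_call` hook added above is a no-op on `BaseChat` and is applied to the final answer just before it is stored as the view message, so subclasses can decorate the output. A minimal sketch of that override pattern (the class names and the `finish` helper are illustrative, not the project's API):

```python
class BaseChatSketch:
    def knowledge_reference_call(self, text: str) -> str:
        # Base behaviour: return the model output unchanged.
        return text

    def finish(self, msg: str) -> str:
        # Mirrors the diff: decorate the final answer before it becomes
        # the stored view message.
        return self.knowledge_reference_call(msg)


class ChatKnowledgeSketch(BaseChatSketch):
    def __init__(self, sources):
        self.sources = sources

    def knowledge_reference_call(self, text: str) -> str:
        refs = "\n".join(f"- {s}" for s in self.sources)
        return f"{text}\n\n##### **References:**\n{refs}"


print(ChatKnowledgeSketch(["report.pdf, **pages**:1,2"]).finish("Answer."))
```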

View File

@ -21,13 +21,13 @@ class ExtractRefineSummary(BaseChat):
chat_param=chat_param,
)
self.user_input = chat_param["current_user_input"]
# self.user_input = chat_param["current_user_input"]
self.existing_answer = chat_param["select_param"]
# self.extract_mode = chat_param["select_param"]
def generate_input_values(self):
input_values = {
"context": self.user_input,
# "context": self.user_input,
"existing_answer": self.existing_answer,
}
return input_values

View File

@ -3,18 +3,20 @@ from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_knowledge.refine_summary.out_parser import ExtractRefineSummaryParser
from pilot.scene.chat_knowledge.refine_summary.out_parser import (
ExtractRefineSummaryParser,
)
CFG = Config()
PROMPT_SCENE_DEFINE = """"""
_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 我们有机会在下面提供的更多上下文信息的基础上进一步完善现有的总结(仅在需要的情况下)。请根据新的上下文信息,完善原来的总结。\n------------\n{context}\n------------\n如果上下文信息没有用处,请返回原来的总结。"""
_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 请再完善一下原来的总结。\n回答的时候最好按照1.2.3.点进行总结"""
_DEFAULT_TEMPLATE_EN = """
We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.\n------------\n{context}\n------------\nGiven the new context, refine the original summary. \nIf the context isn't useful, return the original summary.
please use original language.
We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed). Please refine the original summary.
\nWhen answering, it is best to summarize according to points 1.2.3.
"""
_DEFAULT_TEMPLATE = (
@ -29,7 +31,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False
prompt = PromptTemplate(
template_scene=ChatScene.ExtractRefineSummary.value(),
input_variables=["existing_answer","context"],
input_variables=["existing_answer"],
response_format="",
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
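With `context` removed from `input_variables`, the refine prompt is now driven by `existing_answer` alone. A quick format smoke test using plain `str.format` (the project's `PromptTemplate` wrapper is not required for this check):

```python
_TEMPLATE_EN = (
    "We have provided an existing summary up to a certain point: {existing_answer}\n"
    "We have the opportunity to refine the existing summary (only if needed). "
    "Please refine the original summary.\n"
    "When answering, it is best to summarize according to points 1.2.3."
)
print(_TEMPLATE_EN.format(existing_answer="1. DB-GPT supports knowledge spaces."))
```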

View File

@ -11,14 +11,15 @@ CFG = Config()
PROMPT_SCENE_DEFINE = """"""
_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行简洁地总结:
_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行总结:
{context}
回答的时候最好按照1.2.3.点进行总结
"""
_DEFAULT_TEMPLATE_EN = """
Write a concise summary of the following context:
Write a summary of the following context:
{context}
please use original language.
When answering, it is best to summarize according to points 1.2.3.
"""
_DEFAULT_TEMPLATE = (

View File

@ -1,5 +1,5 @@
import os
from typing import Dict
from typing import Dict, List
from pilot.component import ComponentType
from pilot.scene.base_chat import BaseChat
@ -20,7 +20,6 @@ CFG = Config()
class ChatKnowledge(BaseChat):
chat_scene: str = ChatScene.ChatKnowledge.value()
"""KBQA Chat Module"""
def __init__(self, chat_param: Dict):
@ -46,7 +45,6 @@ class ChatKnowledge(BaseChat):
if self.space_context is None
else int(self.space_context["embedding"]["topk"])
)
# self.recall_score = CFG.KNOWLEDGE_SEARCH_TOP_SIZE if self.space_context is None else self.space_context["embedding"]["recall_score"]
self.max_token = (
CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
if self.space_context is None
@ -56,11 +54,11 @@ class ChatKnowledge(BaseChat):
"vector_store_name": self.knowledge_space,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
}
from pilot.graph_engine.graph_factory import RAGGraphFactory
self.rag_engine = CFG.SYSTEM_APP.get_component(
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
).create()
# from pilot.graph_engine.graph_factory import RAGGraphFactory
#
# self.rag_engine = CFG.SYSTEM_APP.get_component(
# ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
# ).create()
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
@ -90,25 +88,29 @@ class ChatKnowledge(BaseChat):
last_output.text = (
last_output.text + "\n\nrelations:\n\n" + ",".join(relations)
)
yield last_output
reference = f"\n\n{self.parse_source_view(self.sources)}"
last_output = last_output + reference
yield last_output
def knowledge_reference_call(self, text):
"""return reference"""
return text + f"\n\n{self.parse_source_view(self.sources)}"
async def generate_input_values(self) -> Dict:
if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
self.prompt_template.template = self.space_context["prompt"]["template"]
# docs = self.knowledge_embedding_client.similar_search(
# self.current_user_input, self.top_k
# )
docs = await blocking_func_to_async(
self._executor,
self.knowledge_embedding_client.similar_search,
self.current_user_input,
self.top_k,
)
docs = await self.rag_engine.search(query=self.current_user_input)
# docs = self.knowledge_embedding_client.similar_search(
# self.current_user_input, self.top_k
# )
self.sources = self.merge_by_key(
list(map(lambda doc: doc.metadata, docs)), "source"
)
self.current_message.knowledge_source = self.sources
if not docs:
raise ValueError(
"you have no knowledge space, please add your knowledge space"
@ -125,6 +127,42 @@ class ChatKnowledge(BaseChat):
}
return input_values
def parse_source_view(self, sources: List):
html_title = f"##### **References:**"
lines = ""
for item in sources:
source = item["source"] if "source" in item else ""
pages = ",".join(item["pages"]) if "pages" in item else ""
lines += f"{source}"
if len(pages) > 0:
lines += f", **pages**:{pages}\n\n"
else:
lines += "\n\n"
html = f"""{html_title}\n{lines}"""
return html
def merge_by_key(self, data, key):
result = {}
for item in data:
item_key = os.path.basename(item.get(key))
if item_key in result:
if "pages" in result[item_key] and "page" in item:
result[item_key]["pages"].append(str(item["page"]))
elif "page" in item:
result[item_key]["pages"] = [
result[item_key]["pages"],
str(item["page"]),
]
else:
if "page" in item:
result[item_key] = {
"source": item_key,
"pages": [str(item["page"])],
}
else:
result[item_key] = {"source": item_key}
return list(result.values())
@property
def chat_type(self) -> str:
return ChatScene.ChatKnowledge.value()
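For reference, a condensed standalone restatement of the two helpers above (not line-for-line identical to the diff; it only shows the reference block they produce from document metadata):

```python
import os
from typing import Dict, List

def merge_by_key(data: List[Dict], key: str) -> List[Dict]:
    # Group metadata dicts by file name and collect their page numbers,
    # as ChatKnowledge.merge_by_key does.
    result: Dict[str, Dict] = {}
    for item in data:
        name = os.path.basename(item.get(key))
        entry = result.setdefault(name, {"source": name})
        if "page" in item:
            entry.setdefault("pages", []).append(str(item["page"]))
    return list(result.values())

def parse_source_view(sources: List[Dict]) -> str:
    # Render the grouped sources as the markdown block appended to answers.
    lines = ""
    for item in sources:
        lines += item.get("source", "")
        pages = ",".join(item.get("pages", []))
        lines += f", **pages**:{pages}\n\n" if pages else "\n\n"
    return f"##### **References:**\n{lines}"

metadata = [
    {"source": "/data/report.pdf", "page": 1},
    {"source": "/data/report.pdf", "page": 3},
    {"source": "/data/notes.md"},
]
print(parse_source_view(merge_by_key(metadata, "source")))
# ##### **References:**
# report.pdf, **pages**:1,3
#
# notes.md
```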

View File

@ -59,6 +59,8 @@ class DocumentSyncRequest(BaseModel):
"""doc_ids: doc ids"""
doc_ids: List
model_name: Optional[str] = None
"""Preseparator, this separator is used for pre-splitting before the document is actually split by the text splitter.
The preseparator is not included in the vectorized text.
"""

View File

@ -5,6 +5,7 @@ from pydantic import BaseModel
class ChunkQueryResponse(BaseModel):
"""data: data"""
data: List = None
"""summary: document summary"""
summary: str = None

View File

@ -199,6 +199,7 @@ class KnowledgeService:
# import langchain is very very slow!!!
doc_ids = sync_request.doc_ids
self.model_name = sync_request.model_name or CFG.LLM_MODEL
for doc_id in doc_ids:
query = KnowledgeDocumentEntity(
id=doc_id,
@ -427,11 +428,16 @@ class KnowledgeService:
- doc: KnowledgeDocumentEntity
"""
from llama_index import PromptHelper
from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
texts = [doc.page_content for doc in chunk_docs]
prompt_helper = PromptHelper(context_window=2500)
from llama_index.prompts.default_prompt_selectors import (
DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
)
texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts)
texts = [doc.page_content for doc in chunk_docs]
prompt_helper = PromptHelper(context_window=2000)
texts = prompt_helper.repack(
prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts
)
logger.info(
f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
)
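Before generating the summary, the chunk texts are repacked so each packed string fits the reduced context window together with the tree-summarize prompt. A hedged sketch of that step in isolation, assuming a llama_index release that exposes the same imports used in the diff:

```python
from llama_index import PromptHelper
from llama_index.prompts.default_prompt_selectors import (
    DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
)

texts = ["chunk one ...", "chunk two ...", "chunk three ..."]
prompt_helper = PromptHelper(context_window=2000)
# repack() merges the chunks into as few strings as still fit the window
# once the tree-summarize prompt is accounted for.
packed = prompt_helper.repack(
    prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts
)
print(f"repacked {len(texts)} chunks into {len(packed)} text(s)")
```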
@ -445,13 +451,10 @@ class KnowledgeService:
# print(
# f"refine summary outputs:{summaries}"
# )
print(
f"final summary:{summary}"
)
print(f"final summary:{summary}")
doc.summary = summary
return knowledge_document_dao.update_knowledge_document(doc)
def async_doc_embedding(self, client, chunk_docs, doc):
"""async document embedding into vector db
Args:
@ -460,11 +463,11 @@ class KnowledgeService:
- doc: KnowledgeDocumentEntity
"""
logger.info(
f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
f"async doc sync, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
)
try:
vector_ids = client.knowledge_embedding_batch(chunk_docs)
self.async_document_summary(chunk_docs, doc)
vector_ids = client.knowledge_embedding_batch(chunk_docs)
doc.status = SyncStatus.FINISHED.name
doc.result = "document embedding success"
if vector_ids is not None:
@ -526,25 +529,27 @@ class KnowledgeService:
chat_param = {
"chat_session_id": uuid.uuid1(),
"current_user_input": doc,
"select_param": "summary",
"model_name": CFG.LLM_MODEL,
"select_param": doc,
"model_name": self.model_name,
}
from pilot.utils import utils
loop = utils.get_or_create_event_loop()
summary = loop.run_until_complete(
llm_chat_response_nostream(
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
)
from pilot.common.chat_util import run_async_tasks
summary_iters = run_async_tasks(
[
llm_chat_response_nostream(
ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
)
]
)
return summary
def _refine_extract_summary(self, docs, summary: str, max_iteration:int = 5):
return summary_iters[0]
def _refine_extract_summary(self, docs, summary: str, max_iteration: int = 5):
"""Extract refine summary by llm"""
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
import uuid
print(
f"initialize summary is :{summary}"
)
print(f"initialize summary is :{summary}")
outputs = [summary]
max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
for doc in docs[0:max_iteration]:
@ -552,9 +557,10 @@ class KnowledgeService:
"chat_session_id": uuid.uuid1(),
"current_user_input": doc,
"select_param": summary,
"model_name": CFG.LLM_MODEL,
"model_name": self.model_name,
}
from pilot.utils import utils
loop = utils.get_or_create_event_loop()
summary = loop.run_until_complete(
llm_chat_response_nostream(
@ -562,9 +568,7 @@ class KnowledgeService:
)
)
outputs.append(summary)
print(
f"iterator is {len(outputs)} current summary is :{summary}"
)
print(f"iterator is {len(outputs)} current summary is :{summary}")
return outputs, summary
def _mapreduce_extract_summary(self, docs):
@ -577,6 +581,7 @@ class KnowledgeService:
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
import uuid
tasks = []
max_iteration = 5
if len(docs) == 1:
@ -589,17 +594,23 @@ class KnowledgeService:
"chat_session_id": uuid.uuid1(),
"current_user_input": doc,
"select_param": "summary",
"model_name": CFG.LLM_MODEL,
"model_name": self.model_name,
}
tasks.append(llm_chat_response_nostream(
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
))
tasks.append(
llm_chat_response_nostream(
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
)
)
from pilot.common.chat_util import run_async_tasks
summary_iters = run_async_tasks(tasks)
from pilot.common.prompt_util import PromptHelper
from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
from llama_index.prompts.default_prompt_selectors import (
DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
)
prompt_helper = PromptHelper(context_window=2500)
summary_iters = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters)
summary_iters = prompt_helper.repack(
prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters
)
return self._mapreduce_extract_summary(summary_iters)
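Taken together, `_mapreduce_extract_summary` has a map-reduce shape: summarize up to `max_iteration` chunks, repack the partial summaries under the prompt budget, and recurse until a single text remains. A minimal, LLM-free sketch of that control flow, where `summarize` and `repack` are toy stand-ins for the ExtractSummary chat call and `PromptHelper.repack`:

```python
from typing import List

def summarize(text: str) -> str:
    # Toy stand-in for the per-chunk ExtractSummary LLM call.
    return text[:40]

def repack(texts: List[str], max_len: int = 80) -> List[str]:
    # Toy stand-in for PromptHelper.repack: greedily merge texts under a budget.
    packed, cur = [], ""
    for t in texts:
        if cur and len(cur) + len(t) > max_len:
            packed.append(cur)
            cur = ""
        cur += t
    if cur:
        packed.append(cur)
    return packed

def mapreduce_summary(docs: List[str], max_iteration: int = 5) -> str:
    if len(docs) == 1:
        return summarize(docs[0])
    partial = [summarize(d) for d in docs[:max_iteration]]  # map step
    return mapreduce_summary(repack(partial))                # reduce step

print(mapreduce_summary(["long chunk " * 20, "another chunk " * 20]))
```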