diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py index 491a8625c..bea5f3123 100644 --- a/pilot/graph_engine/graph_engine.py +++ b/pilot/graph_engine/graph_engine.py @@ -187,6 +187,7 @@ class RAGGraphEngine: triple_results = [] for doc in docs: import threading + thread_id = threading.get_ident() print(f"current thread-{thread_id} begin extract triplets task") triplets = self._extract_triplets(doc.page_content) diff --git a/pilot/graph_engine/graph_search.py b/pilot/graph_engine/graph_search.py index f3025be85..9419a4979 100644 --- a/pilot/graph_engine/graph_search.py +++ b/pilot/graph_engine/graph_search.py @@ -143,9 +143,7 @@ class RAGGraphSearch(BaseSearch): logger.info("> No relationships found, returning nodes found by keywords.") if len(sorted_nodes_with_scores) == 0: logger.info("> No nodes found by keywords, returning empty response.") - return [ - Document(page_content="No relationships found.") - ] + return [Document(page_content="No relationships found.")] # add relationships as Node # TODO: make initial text customizable diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 5c33f4770..34f294c31 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -141,7 +141,6 @@ class BaseChat(ABC): self.current_message.start_date = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) - self.current_message.tokens = 0 if self.prompt_template.template: current_prompt = self.prompt_template.format(**input_values) @@ -152,7 +151,6 @@ class BaseChat(ABC): # Not new server mode, we convert the message format(List[ModelMessage]) to list of dict # fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post llm_messages = list(map(lambda m: m.dict(), llm_messages)) - payload = { "model": self.llm_model, "prompt": self.generate_llm_text(), @@ -167,6 +165,9 @@ class BaseChat(ABC): def stream_plugin_call(self, text): return text + def knowledge_reference_call(self, text): + return text + async def check_iterator_end(iterator): try: await asyncio.anext(iterator) @@ -196,6 +197,7 @@ class BaseChat(ABC): view_msg = view_msg.replace("\n", "\\n") yield view_msg self.current_message.add_ai_message(msg) + view_msg = self.knowledge_reference_call(msg) self.current_message.add_view_message(view_msg) except Exception as e: print(traceback.format_exc()) diff --git a/pilot/scene/chat_knowledge/refine_summary/chat.py b/pilot/scene/chat_knowledge/refine_summary/chat.py index b3a934dd5..d0b1e9471 100644 --- a/pilot/scene/chat_knowledge/refine_summary/chat.py +++ b/pilot/scene/chat_knowledge/refine_summary/chat.py @@ -21,13 +21,13 @@ class ExtractRefineSummary(BaseChat): chat_param=chat_param, ) - self.user_input = chat_param["current_user_input"] + # self.user_input = chat_param["current_user_input"] self.existing_answer = chat_param["select_param"] # self.extract_mode = chat_param["select_param"] def generate_input_values(self): input_values = { - "context": self.user_input, + # "context": self.user_input, "existing_answer": self.existing_answer, } return input_values diff --git a/pilot/scene/chat_knowledge/refine_summary/prompt.py b/pilot/scene/chat_knowledge/refine_summary/prompt.py index 69d4e46df..cd5087f35 100644 --- a/pilot/scene/chat_knowledge/refine_summary/prompt.py +++ b/pilot/scene/chat_knowledge/refine_summary/prompt.py @@ -3,18 +3,20 @@ from pilot.configs.config import Config from pilot.scene.base import ChatScene from pilot.common.schema import SeparatorStyle -from pilot.scene.chat_knowledge.refine_summary.out_parser import ExtractRefineSummaryParser +from pilot.scene.chat_knowledge.refine_summary.out_parser import ( + ExtractRefineSummaryParser, +) CFG = Config() PROMPT_SCENE_DEFINE = """""" -_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 我们有机会在下面提供的更多上下文信息的基础上进一步完善现有的总结(仅在需要的情况下)。请根据新的上下文信息,完善原来的总结。\n------------\n{context}\n------------\n如果上下文信息没有用处,请返回原来的总结。""" +_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 请再完善一下原来的总结。\n回答的时候最好按照1.2.3.点进行总结""" _DEFAULT_TEMPLATE_EN = """ -We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.\n------------\n{context}\n------------\nGiven the new context, refine the original summary. \nIf the context isn't useful, return the original summary. -please use original language. +We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.please refine the original summary. +\nWhen answering, it is best to summarize according to points 1.2.3. """ _DEFAULT_TEMPLATE = ( @@ -29,7 +31,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False prompt = PromptTemplate( template_scene=ChatScene.ExtractRefineSummary.value(), - input_variables=["existing_answer","context"], + input_variables=["existing_answer"], response_format="", template_define=PROMPT_SCENE_DEFINE, template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE, diff --git a/pilot/scene/chat_knowledge/summary/prompt.py b/pilot/scene/chat_knowledge/summary/prompt.py index ec7c05c32..10a239586 100644 --- a/pilot/scene/chat_knowledge/summary/prompt.py +++ b/pilot/scene/chat_knowledge/summary/prompt.py @@ -11,14 +11,15 @@ CFG = Config() PROMPT_SCENE_DEFINE = """""" -_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行简洁地总结: +_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行总结: {context} +回答的时候最好按照1.2.3.点进行总结 """ _DEFAULT_TEMPLATE_EN = """ -Write a concise summary of the following context: +Write a summary of the following context: {context} -please use original language. +When answering, it is best to summarize according to points 1.2.3. """ _DEFAULT_TEMPLATE = ( diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index 576a93872..e6c3d9056 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -1,5 +1,5 @@ import os -from typing import Dict +from typing import Dict, List from pilot.component import ComponentType from pilot.scene.base_chat import BaseChat @@ -20,7 +20,6 @@ CFG = Config() class ChatKnowledge(BaseChat): chat_scene: str = ChatScene.ChatKnowledge.value() - """KBQA Chat Module""" def __init__(self, chat_param: Dict): @@ -46,7 +45,6 @@ class ChatKnowledge(BaseChat): if self.space_context is None else int(self.space_context["embedding"]["topk"]) ) - # self.recall_score = CFG.KNOWLEDGE_SEARCH_TOP_SIZE if self.space_context is None else self.space_context["embedding"]["recall_score"] self.max_token = ( CFG.KNOWLEDGE_SEARCH_MAX_TOKEN if self.space_context is None @@ -56,11 +54,11 @@ class ChatKnowledge(BaseChat): "vector_store_name": self.knowledge_space, "vector_store_type": CFG.VECTOR_STORE_TYPE, } - from pilot.graph_engine.graph_factory import RAGGraphFactory - - self.rag_engine = CFG.SYSTEM_APP.get_component( - ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory - ).create() + # from pilot.graph_engine.graph_factory import RAGGraphFactory + # + # self.rag_engine = CFG.SYSTEM_APP.get_component( + # ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory + # ).create() embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) @@ -90,25 +88,29 @@ class ChatKnowledge(BaseChat): last_output.text = ( last_output.text + "\n\nrelations:\n\n" + ",".join(relations) ) - yield last_output + reference = f"\n\n{self.parse_source_view(self.sources)}" + last_output = last_output + reference + yield last_output + + def knowledge_reference_call(self, text): + """return reference""" + return text + f"\n\n{self.parse_source_view(self.sources)}" async def generate_input_values(self) -> Dict: if self.space_context: self.prompt_template.template_define = self.space_context["prompt"]["scene"] self.prompt_template.template = self.space_context["prompt"]["template"] - # docs = self.knowledge_embedding_client.similar_search( - # self.current_user_input, self.top_k - # ) docs = await blocking_func_to_async( self._executor, self.knowledge_embedding_client.similar_search, self.current_user_input, self.top_k, ) - docs = await self.rag_engine.search(query=self.current_user_input) - # docs = self.knowledge_embedding_client.similar_search( - # self.current_user_input, self.top_k - # ) + self.sources = self.merge_by_key( + list(map(lambda doc: doc.metadata, docs)), "source" + ) + + self.current_message.knowledge_source = self.sources if not docs: raise ValueError( "you have no knowledge space, please add your knowledge space" @@ -125,6 +127,42 @@ class ChatKnowledge(BaseChat): } return input_values + def parse_source_view(self, sources: List): + html_title = f"##### **References:**" + lines = "" + for item in sources: + source = item["source"] if "source" in item else "" + pages = ",".join(item["pages"]) if "pages" in item else "" + lines += f"{source}" + if len(pages) > 0: + lines += f", **pages**:{pages}\n\n" + else: + lines += "\n\n" + html = f"""{html_title}\n{lines}""" + return html + + def merge_by_key(self, data, key): + result = {} + for item in data: + item_key = os.path.basename(item.get(key)) + if item_key in result: + if "pages" in result[item_key] and "page" in item: + result[item_key]["pages"].append(str(item["page"])) + elif "page" in item: + result[item_key]["pages"] = [ + result[item_key]["pages"], + str(item["page"]), + ] + else: + if "page" in item: + result[item_key] = { + "source": item_key, + "pages": [str(item["page"])], + } + else: + result[item_key] = {"source": item_key} + return list(result.values()) + @property def chat_type(self) -> str: return ChatScene.ChatKnowledge.value() diff --git a/pilot/server/knowledge/request/request.py b/pilot/server/knowledge/request/request.py index c6b94ff0d..032b97ba1 100644 --- a/pilot/server/knowledge/request/request.py +++ b/pilot/server/knowledge/request/request.py @@ -59,6 +59,8 @@ class DocumentSyncRequest(BaseModel): """doc_ids: doc ids""" doc_ids: List + model_name: Optional[str] = None + """Preseparator, this separator is used for pre-splitting before the document is actually split by the text splitter. Preseparator are not included in the vectorized text. """ diff --git a/pilot/server/knowledge/request/response.py b/pilot/server/knowledge/request/response.py index 2e3e5f0ab..5c1c7efd1 100644 --- a/pilot/server/knowledge/request/response.py +++ b/pilot/server/knowledge/request/response.py @@ -5,6 +5,7 @@ from pydantic import BaseModel class ChunkQueryResponse(BaseModel): """data: data""" + data: List = None """summary: document summary""" summary: str = None diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py index 7e899ba78..1dece9054 100644 --- a/pilot/server/knowledge/service.py +++ b/pilot/server/knowledge/service.py @@ -199,6 +199,7 @@ class KnowledgeService: # import langchain is very very slow!!! doc_ids = sync_request.doc_ids + self.model_name = sync_request.model_name or CFG.LLM_MODEL for doc_id in doc_ids: query = KnowledgeDocumentEntity( id=doc_id, @@ -427,11 +428,16 @@ class KnowledgeService: - doc: KnowledgeDocumentEntity """ from llama_index import PromptHelper - from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL - texts = [doc.page_content for doc in chunk_docs] - prompt_helper = PromptHelper(context_window=2500) + from llama_index.prompts.default_prompt_selectors import ( + DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, + ) - texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts) + texts = [doc.page_content for doc in chunk_docs] + prompt_helper = PromptHelper(context_window=2000) + + texts = prompt_helper.repack( + prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts + ) logger.info( f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary" ) @@ -445,13 +451,10 @@ class KnowledgeService: # print( # f"refine summary outputs:{summaries}" # ) - print( - f"final summary:{summary}" - ) + print(f"final summary:{summary}") doc.summary = summary return knowledge_document_dao.update_knowledge_document(doc) - def async_doc_embedding(self, client, chunk_docs, doc): """async document embedding into vector db Args: @@ -460,11 +463,11 @@ class KnowledgeService: - doc: KnowledgeDocumentEntity """ logger.info( - f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}" + f"async doc sync, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}" ) try: - vector_ids = client.knowledge_embedding_batch(chunk_docs) self.async_document_summary(chunk_docs, doc) + vector_ids = client.knowledge_embedding_batch(chunk_docs) doc.status = SyncStatus.FINISHED.name doc.result = "document embedding success" if vector_ids is not None: @@ -526,25 +529,27 @@ class KnowledgeService: chat_param = { "chat_session_id": uuid.uuid1(), "current_user_input": doc, - "select_param": "summary", - "model_name": CFG.LLM_MODEL, + "select_param": doc, + "model_name": self.model_name, } - from pilot.utils import utils - loop = utils.get_or_create_event_loop() - summary = loop.run_until_complete( - llm_chat_response_nostream( - ChatScene.ExtractSummary.value(), **{"chat_param": chat_param} - ) + from pilot.common.chat_util import run_async_tasks + + summary_iters = run_async_tasks( + [ + llm_chat_response_nostream( + ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param} + ) + ] ) - return summary - def _refine_extract_summary(self, docs, summary: str, max_iteration:int = 5): + return summary_iters[0] + + def _refine_extract_summary(self, docs, summary: str, max_iteration: int = 5): """Extract refine summary by llm""" from pilot.scene.base import ChatScene from pilot.common.chat_util import llm_chat_response_nostream import uuid - print( - f"initialize summary is :{summary}" - ) + + print(f"initialize summary is :{summary}") outputs = [summary] max_iteration = max_iteration if len(docs) > max_iteration else len(docs) for doc in docs[0:max_iteration]: @@ -552,9 +557,10 @@ class KnowledgeService: "chat_session_id": uuid.uuid1(), "current_user_input": doc, "select_param": summary, - "model_name": CFG.LLM_MODEL, + "model_name": self.model_name, } from pilot.utils import utils + loop = utils.get_or_create_event_loop() summary = loop.run_until_complete( llm_chat_response_nostream( @@ -562,9 +568,7 @@ class KnowledgeService: ) ) outputs.append(summary) - print( - f"iterator is {len(outputs)} current summary is :{summary}" - ) + print(f"iterator is {len(outputs)} current summary is :{summary}") return outputs, summary def _mapreduce_extract_summary(self, docs): @@ -577,6 +581,7 @@ class KnowledgeService: from pilot.scene.base import ChatScene from pilot.common.chat_util import llm_chat_response_nostream import uuid + tasks = [] max_iteration = 5 if len(docs) == 1: @@ -589,17 +594,23 @@ class KnowledgeService: "chat_session_id": uuid.uuid1(), "current_user_input": doc, "select_param": "summary", - "model_name": CFG.LLM_MODEL, + "model_name": self.model_name, } - tasks.append(llm_chat_response_nostream( - ChatScene.ExtractSummary.value(), **{"chat_param": chat_param} - )) + tasks.append( + llm_chat_response_nostream( + ChatScene.ExtractSummary.value(), **{"chat_param": chat_param} + ) + ) from pilot.common.chat_util import run_async_tasks + summary_iters = run_async_tasks(tasks) from pilot.common.prompt_util import PromptHelper - from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL + from llama_index.prompts.default_prompt_selectors import ( + DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, + ) + prompt_helper = PromptHelper(context_window=2500) - summary_iters = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters) + summary_iters = prompt_helper.repack( + prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters + ) return self._mapreduce_extract_summary(summary_iters) - -