From 53b1fc40901cb59ebcdc93caaa852c942ea5f858 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Mon, 30 Oct 2023 19:06:09 +0800 Subject: [PATCH] feat:document summary --- pilot/scene/base.py | 7 ++ pilot/scene/base_chat.py | 2 - pilot/scene/chat_factory.py | 1 + pilot/server/knowledge/document_db.py | 3 +- pilot/server/knowledge/request/response.py | 3 +- pilot/server/knowledge/service.py | 89 ++++++++++++++++------ 6 files changed, 79 insertions(+), 26 deletions(-) diff --git a/pilot/scene/base.py b/pilot/scene/base.py index 5c98003d9..e3478f7c3 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -96,6 +96,13 @@ class ChatScene(Enum): ["Extract Select"], True, ) + ExtractRefineSummary = Scene( + "extract_refine_summary", + "Extract Summary", + "Extract Summary", + ["Extract Select"], + True, + ) ExtractEntity = Scene( "extract_entity", "Extract Entity", "Extract Entity", ["Extract Select"], True ) diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 73a2c5ef6..e43cdc812 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -127,8 +127,6 @@ class BaseChat(ABC): speak_to_user = prompt_define_response return speak_to_user - async def __call_base(self): - input_values = await self.generate_input_values() async def __call_base(self): import inspect diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py index 2e103f15d..10a588c04 100644 --- a/pilot/scene/chat_factory.py +++ b/pilot/scene/chat_factory.py @@ -17,6 +17,7 @@ class ChatFactory(metaclass=Singleton): from pilot.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet from pilot.scene.chat_knowledge.extract_entity.chat import ExtractEntity from pilot.scene.chat_knowledge.summary.chat import ExtractSummary + from pilot.scene.chat_knowledge.refine_summary.chat import ExtractRefineSummary from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel from pilot.scene.chat_agent.chat import ChatAgent diff --git a/pilot/server/knowledge/document_db.py b/pilot/server/knowledge/document_db.py index 3e6dfb0c4..bbe1426d7 100644 --- a/pilot/server/knowledge/document_db.py +++ b/pilot/server/knowledge/document_db.py @@ -30,11 +30,12 @@ class KnowledgeDocumentEntity(Base): content = Column(Text) result = Column(Text) vector_ids = Column(Text) + summary = Column(Text) gmt_created = Column(DateTime) gmt_modified = Column(DateTime) def __repr__(self): - return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', summary='{self.summary}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" class KnowledgeDocumentDao(BaseDao): diff --git a/pilot/server/knowledge/request/response.py b/pilot/server/knowledge/request/response.py index fb7aa55e9..2e3e5f0ab 100644 --- a/pilot/server/knowledge/request/response.py +++ b/pilot/server/knowledge/request/response.py @@ -5,8 +5,9 @@ from pydantic import BaseModel class ChunkQueryResponse(BaseModel): """data: data""" - data: List = None + """summary: document summary""" + summary: str = None """total: total size""" total: int = None """page: current page""" diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py index 4c1c41994..017fef3ec 100644 --- a/pilot/server/knowledge/service.py +++ b/pilot/server/knowledge/service.py @@ -288,8 +288,8 @@ class KnowledgeService: executor = CFG.SYSTEM_APP.get_component( ComponentType.EXECUTOR_DEFAULT, ExecutorFactory ).create() - executor.submit(self.async_knowledge_graph, chunk_docs, doc) - # executor.submit(self.async_doc_embedding, client, chunk_docs, doc) + executor.submit(self.async_document_summary, chunk_docs, doc) + executor.submit(self.async_doc_embedding, client, chunk_docs, doc) logger.info(f"begin save document chunks, doc:{doc.doc_name}") # save chunk details chunk_entities = [ @@ -384,38 +384,59 @@ class KnowledgeService: doc_name=request.doc_name, doc_type=request.doc_type, ) + document_query = KnowledgeDocumentEntity(id=request.document_id) + documents = knowledge_document_dao.get_documents(document_query) + res = ChunkQueryResponse() res.data = document_chunk_dao.get_document_chunks( query, page=request.page, page_size=request.page_size ) + res.summary = documents[0].summary res.total = document_chunk_dao.get_document_chunks_count(query) res.page = request.page return res + def async_knowledge_graph(self, chunk_docs, doc): """async document extract triplets and save into graph db Args: - chunk_docs: List[Document] - doc: KnowledgeDocumentEntity """ - for doc in chunk_docs: - text = doc.page_content - self._llm_extract_summary(text) logger.info( f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store" ) - # try: - # from pilot.graph_engine.graph_factory import RAGGraphFactory - # - # rag_engine = CFG.SYSTEM_APP.get_component( - # ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory - # ).create() - # rag_engine.knowledge_graph(chunk_docs) - # doc.status = SyncStatus.FINISHED.name - # doc.result = "document build graph success" - # except Exception as e: - # doc.status = SyncStatus.FAILED.name - # doc.result = "document build graph failed" + str(e) - # logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}") + try: + from pilot.graph_engine.graph_factory import RAGGraphFactory + + rag_engine = CFG.SYSTEM_APP.get_component( + ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory + ).create() + rag_engine.knowledge_graph(chunk_docs) + doc.status = SyncStatus.FINISHED.name + doc.result = "document build graph success" + except Exception as e: + doc.status = SyncStatus.FAILED.name + doc.result = "document build graph failed" + str(e) + logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}") + return knowledge_document_dao.update_knowledge_document(doc) + + def async_document_summary(self, chunk_docs, doc): + """async document extract summary + Args: + - chunk_docs: List[Document] + - doc: KnowledgeDocumentEntity + """ + from llama_index import PromptHelper + from llama_index.prompts.default_prompt_selectors import DEFAULT_TREE_SUMMARIZE_PROMPT_SEL + texts = [doc.page_content for doc in chunk_docs] + prompt_helper = PromptHelper() + texts = prompt_helper.repack(prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts) + summary = self._llm_extract_summary(chunk_docs[0]) + outputs, summary = self._refine_extract_summary(texts[1:], summary) + logger.info( + f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store" + ) + doc.summary = summary return knowledge_document_dao.update_knowledge_document(doc) @@ -491,15 +512,39 @@ class KnowledgeService: chat_param = { "chat_session_id": uuid.uuid1(), - "current_user_input": doc, - "select_param": "summery", + "current_user_input": doc.page_content, + "select_param": "summary", "model_name": "proxyllm", } from pilot.utils import utils loop = utils.get_or_create_event_loop() - triplets = loop.run_until_complete( + summary = loop.run_until_complete( llm_chat_response_nostream( ChatScene.ExtractSummary.value(), **{"chat_param": chat_param} ) ) - return triplets + return summary + def _refine_extract_summary(self, docs, summary: str): + """Extract refine summary by llm""" + from pilot.scene.base import ChatScene + from pilot.common.chat_util import llm_chat_response_nostream + import uuid + outputs = [] + for doc in docs: + chat_param = { + "chat_session_id": uuid.uuid1(), + "current_user_input": doc, + "select_param": summary, + "model_name": "proxyllm", + } + from pilot.utils import utils + loop = utils.get_or_create_event_loop() + summary = loop.run_until_complete( + llm_chat_response_nostream( + ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param} + ) + ) + outputs.append(summary) + return outputs, summary + +