From 724456dc3e82e9bc53d9cdf21bbd485f4a48e4eb Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Wed, 25 Oct 2023 21:18:37 +0800
Subject: [PATCH] feat:extract summary

---
 pilot/graph_engine/graph_engine.py | 115 ++++++++++++++++++-----------
 pilot/graph_engine/graph_search.py |  19 +++--
 pilot/scene/base.py                |   7 ++
 pilot/scene/chat_factory.py        |   1 +
 pilot/server/knowledge/service.py  |  60 ++++++++++++---
 5 files changed, 141 insertions(+), 61 deletions(-)

diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py
index faf53aba1..491a8625c 100644
--- a/pilot/graph_engine/graph_engine.py
+++ b/pilot/graph_engine/graph_engine.py
@@ -15,10 +15,12 @@ logger = logging.getLogger(__name__)
 class RAGGraphEngine:
     """Knowledge RAG Graph Engine.
-    Build a KG by extracting triplets, and leveraging the KG during query-time.
+    Build a RAG graph client that extracts triplets and inserts them into the graph store.
     Args:
         knowledge_type (Optional[str]): Default: KnowledgeType.DOCUMENT.value extracting triplets.
+        knowledge_source (Optional[str]): knowledge source to extract triplets from.
+        model_name (Optional[str]): LLM model name.
         graph_store (Optional[GraphStore]): The graph store to use. Reference: llama-index.
         include_embeddings (bool): Whether to include embeddings in the index.
             Defaults to False.
@@ -104,37 +106,64 @@ class RAGGraphEngine:
         return triplets
 
     def _build_index_from_docs(self, documents: List[Document]) -> KG:
-        """Build the index from nodes."""
+        """Build the index from documents.
+        Args:
+            documents (List[Document]): documents to extract triplets from.
+        """
         index_struct = self.index_struct_cls()
-        num_threads = 5
-        chunk_size = (
-            len(documents)
-            if (len(documents) < num_threads)
-            else len(documents) // num_threads
-        )
-
-        import concurrent
-
-        future_tasks = []
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            for i in range(num_threads):
-                start = i * chunk_size
-                end = start + chunk_size if i < num_threads - 1 else None
-                future_tasks.append(
-                    executor.submit(
-                        self._extract_triplets_task,
-                        documents[start:end][0],
-                        index_struct,
-                    )
-                )
-
-        result = [future.result() for future in future_tasks]
+        triplets = []
+        for doc in documents:
+            trips = self._extract_triplets_task([doc], index_struct)
+            triplets.extend(trips)
+            text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
+            for triplet in trips:
+                subj, _, obj = triplet
+                self.graph_store.upsert_triplet(*triplet)
+                index_struct.add_node([subj, obj], text_node)
+        return index_struct
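+        # The multi-threaded variant below is kept for reference only: it
+        # splits `documents` into chunks and extracts triplets concurrently
+        # in a thread pool.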
+        # num_threads = 5
+        # chunk_size = (
+        #     len(documents)
+        #     if (len(documents) < num_threads)
+        #     else len(documents) // num_threads
+        # )
+        #
+        # import concurrent
+        # triplets = []
+        # future_tasks = []
+        # with concurrent.futures.ThreadPoolExecutor() as executor:
+        #     for i in range(num_threads):
+        #         start = i * chunk_size
+        #         end = start + chunk_size if i < num_threads - 1 else None
+        #         # doc = documents[start:end]
+        #         future_tasks.append(
+        #             executor.submit(
+        #                 self._extract_triplets_task,
+        #                 documents[start:end],
+        #                 index_struct,
+        #             )
+        #         )
+        #         # for doc in documents[start:end]:
+        #         #     future_tasks.append(
+        #         #         executor.submit(
+        #         #             self._extract_triplets_task,
+        #         #             doc,
+        #         #             index_struct,
+        #         #         )
+        #         #     )
+        #
+        #     # result = [future.result() for future in future_tasks]
+        #     completed_futures, _ = concurrent.futures.wait(future_tasks, return_when=concurrent.futures.ALL_COMPLETED)
+        #     for future in completed_futures:
+        #         # collect each completed future's result
+        #         result = future.result()
+        #         triplets.extend(result)
+        # print(f"total triplets-{triplets}")
         # for triplet in triplets:
         #     subj, _, obj = triplet
         #     self.graph_store.upsert_triplet(*triplet)
-        #     self.graph_store.upsert_triplet(*triplet)
-        #     index_struct.add_node([subj, obj], text_node)
-        return index_struct
+        #     # index_struct.add_node([subj, obj], text_node)
+        # return index_struct
         # for doc in documents:
         #     triplets = self._extract_triplets(doc.page_content)
         #     if len(triplets) == 0:
@@ -154,20 +183,22 @@ class RAGGraphEngine:
         graph_search = RAGGraphSearch(graph_engine=self)
         return graph_search.search(query)
 
-    def _extract_triplets_task(self, doc, index_struct):
-        import threading
-
-        thread_id = threading.get_ident()
-        print(f"current thread-{thread_id} begin extract triplets task")
-        triplets = self._extract_triplets(doc.page_content)
-        if len(triplets) == 0:
-            triplets = []
-        text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
-        logger.info(f"extracted knowledge triplets: {triplets}")
-        print(
-            f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
-        )
-        return triplets
+    def _extract_triplets_task(self, docs, index_struct):
+        import threading
+
+        thread_id = threading.get_ident()
+        triple_results = []
+        for doc in docs:
+            logger.info(f"thread-{thread_id} begin extract triplets task")
+            triplets = self._extract_triplets(doc.page_content)
+            logger.info(
+                f"thread-{thread_id} end extract triplets task, triplets: {triplets}"
+            )
+            triple_results.extend(triplets)
+        return triple_results
         # for triplet in triplets:
         #     subj, _, obj = triplet
         #     self.graph_store.upsert_triplet(*triplet)
diff --git a/pilot/graph_engine/graph_search.py b/pilot/graph_engine/graph_search.py
index fb883e48b..f3025be85 100644
--- a/pilot/graph_engine/graph_search.py
+++ b/pilot/graph_engine/graph_search.py
@@ -8,7 +8,6 @@ from langchain.schema import Document
 
 from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
 from pilot.graph_engine.search import BaseSearch, SearchMode
-from pilot.utils import utils
 
 logger = logging.getLogger(__name__)
 DEFAULT_NODE_SCORE = 1000.0
@@ -113,15 +112,15 @@ class RAGGraphSearch(BaseSearch):
         for keyword in keywords:
             keyword = keyword.lower()
             subjs = set((keyword,))
-            node_ids = self._index_struct.search_node_by_keyword(keyword)
-            for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
-                if node_id in node_visited:
-                    continue
-
-                # if self._include_text:
-                #     chunk_indices_count[node_id] += 1
-
-                node_visited.add(node_id)
+            # node_ids = self._index_struct.search_node_by_keyword(keyword)
+            # for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
+            #     if node_id in node_visited:
+            #         continue
+            #
+            #     # if self._include_text:
+            #     #     chunk_indices_count[node_id] += 1
+            #
+            #     node_visited.add(node_id)
 
             rel_map = self._graph_store.get_rel_map(
                 list(subjs), self.graph_store_query_depth
diff --git a/pilot/scene/base.py b/pilot/scene/base.py
index 6abc9c937..5c98003d9 100644
--- a/pilot/scene/base.py
+++ b/pilot/scene/base.py
@@ -89,6 +89,13 @@ class ChatScene(Enum):
         ["Extract Select"],
         True,
     )
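+    # Scene for LLM-based summary extraction; handled by the ExtractSummary
+    # chat implementation registered in chat_factory.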
+    ExtractSummary = Scene(
+        "extract_summary",
+        "Extract Summary",
+        "Extract Summary",
+        ["Extract Select"],
+        True,
+    )
     ExtractEntity = Scene(
         "extract_entity", "Extract Entity", "Extract Entity", ["Extract Select"], True
     )
diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py
index de11332f5..a57855a2b 100644
--- a/pilot/scene/chat_factory.py
+++ b/pilot/scene/chat_factory.py
@@ -15,6 +15,7 @@ class ChatFactory(metaclass=Singleton):
         from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
         from pilot.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
         from pilot.scene.chat_knowledge.extract_entity.chat import ExtractEntity
+        from pilot.scene.chat_knowledge.summary.chat import ExtractSummary
         from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
         from pilot.scene.chat_agent.chat import ChatAgent
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index ed8c2846e..4c1c41994 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -280,12 +280,6 @@ class KnowledgeService:
             embedding_factory=embedding_factory,
         )
         chunk_docs = client.read()
-        from pilot.graph_engine.graph_factory import RAGGraphFactory
-
-        rag_engine = CFG.SYSTEM_APP.get_component(
-            ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
-        ).create()
-        rag_engine.knowledge_graph(docs=chunk_docs)
         # update document status
         doc.status = SyncStatus.RUNNING.name
         doc.chunk_size = len(chunk_docs)
@@ -294,8 +288,8 @@ class KnowledgeService:
         executor = CFG.SYSTEM_APP.get_component(
             ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
         ).create()
-        executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
-
+        executor.submit(self.async_knowledge_graph, chunk_docs, doc)
+        # executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
         logger.info(f"begin save document chunks, doc:{doc.doc_name}")
         # save chunk details
         chunk_entities = [
@@ -397,13 +391,40 @@ class KnowledgeService:
         res.total = document_chunk_dao.get_document_chunks_count(query)
         res.page = request.page
         return res
+
+    def async_knowledge_graph(self, chunk_docs, doc):
+        """async extract a summary for each document chunk; the graph-build
+        path is kept below for reference but currently disabled.
+        Args:
+            - chunk_docs: List[Document]
+            - doc: KnowledgeDocumentEntity
+        """
+        logger.info(
+            f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin summary extraction"
+        )
+        for chunk_doc in chunk_docs:
+            text = chunk_doc.page_content
+            self._llm_extract_summary(text)
+        # try:
+        #     from pilot.graph_engine.graph_factory import RAGGraphFactory
+        #
+        #     rag_engine = CFG.SYSTEM_APP.get_component(
+        #         ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+        #     ).create()
+        #     rag_engine.knowledge_graph(chunk_docs)
+        #     doc.status = SyncStatus.FINISHED.name
+        #     doc.result = "document build graph success"
+        # except Exception as e:
+        #     doc.status = SyncStatus.FAILED.name
+        #     doc.result = "document build graph failed" + str(e)
+        #     logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
+        return knowledge_document_dao.update_knowledge_document(doc)
+
     def async_doc_embedding(self, client, chunk_docs, doc):
         """async document embedding into vector db
         Args:
             - client: EmbeddingEngine Client
             - chunk_docs: List[Document]
-            - doc: doc
+            - doc: KnowledgeDocumentEntity
         """
         logger.info(
             f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
         )
@@ -461,3 +482,24 @@ class KnowledgeService:
         if space.context is not None:
             return json.loads(spaces[0].context)
         return None
+
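+    # Runs the ExtractSummary chat scene to completion on a dedicated event
+    # loop and returns the LLM summary for one chunk of text.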
+    def _llm_extract_summary(self, doc: str):
+        """Extract summary from text by llm"""
+        from pilot.scene.base import ChatScene
+        from pilot.common.chat_util import llm_chat_response_nostream
+        import uuid
+
+        chat_param = {
+            "chat_session_id": uuid.uuid1(),
+            "current_user_input": doc,
+            "select_param": "summary",
+            "model_name": "proxyllm",
+        }
+        from pilot.utils import utils
+        loop = utils.get_or_create_event_loop()
+        summary = loop.run_until_complete(
+            llm_chat_response_nostream(
+                ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
+            )
+        )
+        return summary
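+
+    # NOTE: run_until_complete blocks the calling executor thread until the
+    # LLM responds, so async_knowledge_graph summarizes chunks serially.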