feat: extract summary

aries_ckt 2023-10-25 21:18:37 +08:00
parent 318979a7bf
commit 724456dc3e
5 changed files with 141 additions and 61 deletions

View File

@@ -15,10 +15,12 @@ logger = logging.getLogger(__name__)
 class RAGGraphEngine:
     """Knowledge RAG Graph Engine.
-    Build a KG by extracting triplets, and leveraging the KG during query-time.
+    Build a RAG Graph Client that can extract triplets and insert them into a graph store.
     Args:
         knowledge_type (Optional[str]): Default: KnowledgeType.DOCUMENT.value
             extracting triplets.
+        knowledge_source (Optional[str]):
+        model_name (Optional[str]): llm model name
         graph_store (Optional[GraphStore]): The graph store to use. Reference: llama-index
         include_embeddings (bool): Whether to include embeddings in the index.
             Defaults to False.
@@ -104,37 +106,64 @@ class RAGGraphEngine:
         return triplets
     def _build_index_from_docs(self, documents: List[Document]) -> KG:
-        """Build the index from nodes."""
+        """Build the index from nodes.
+        Args:documents:List[Document]
+        """
         index_struct = self.index_struct_cls()
-        num_threads = 5
-        chunk_size = (
-            len(documents)
-            if (len(documents) < num_threads)
-            else len(documents) // num_threads
-        )
-        import concurrent
-        future_tasks = []
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            for i in range(num_threads):
-                start = i * chunk_size
-                end = start + chunk_size if i < num_threads - 1 else None
-                future_tasks.append(
-                    executor.submit(
-                        self._extract_triplets_task,
-                        documents[start:end][0],
-                        index_struct,
-                    )
-                )
-        result = [future.result() for future in future_tasks]
+        triplets = []
+        for doc in documents:
+            trips = self._extract_triplets_task([doc], index_struct)
+            triplets.extend(trips)
+            print(triplets)
+            text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
+            for triplet in triplets:
+                subj, _, obj = triplet
+                self.graph_store.upsert_triplet(*triplet)
+                index_struct.add_node([subj, obj], text_node)
+        return index_struct
+        # num_threads = 5
+        # chunk_size = (
+        #     len(documents)
+        #     if (len(documents) < num_threads)
+        #     else len(documents) // num_threads
+        # )
+        #
+        # import concurrent
+        # triples = []
+        # future_tasks = []
+        # with concurrent.futures.ThreadPoolExecutor() as executor:
+        #     for i in range(num_threads):
+        #         start = i * chunk_size
+        #         end = start + chunk_size if i < num_threads - 1 else None
+        #         # doc = documents[start:end]
+        #         future_tasks.append(
+        #             executor.submit(
+        #                 self._extract_triplets_task,
+        #                 documents[start:end],
+        #                 index_struct,
+        #             )
+        #         )
+        #         # for doc in documents[start:end]:
+        #         #     future_tasks.append(
+        #         #         executor.submit(
+        #         #             self._extract_triplets_task,
+        #         #             doc,
+        #         #             index_struct,
+        #         #         )
+        #         #     )
+        #
+        # # result = [future.result() for future in future_tasks]
+        # completed_futures, _ = concurrent.futures.wait(future_tasks, return_when=concurrent.futures.ALL_COMPLETED)
+        # for future in completed_futures:
+        #     # Get the result of each completed future and append it to the results list
+        #     result = future.result()
+        #     triplets.extend(result)
+        # print(f"total triplets-{triples}")
         # for triplet in triplets:
         #     subj, _, obj = triplet
         #     self.graph_store.upsert_triplet(*triplet)
-        #     self.graph_store.upsert_triplet(*triplet)
-        #     index_struct.add_node([subj, obj], text_node)
-        return index_struct
+        #     # index_struct.add_node([subj, obj], text_node)
+        # return index_struct
         # for doc in documents:
         #     triplets = self._extract_triplets(doc.page_content)
         #     if len(triplets) == 0:
@@ -154,9 +183,10 @@ class RAGGraphEngine:
         graph_search = RAGGraphSearch(graph_engine=self)
         return graph_search.search(query)
-    def _extract_triplets_task(self, doc, index_struct):
+    def _extract_triplets_task(self, docs, index_struct):
+        triple_results = []
+        for doc in docs:
             import threading
             thread_id = threading.get_ident()
             print(f"current thread-{thread_id} begin extract triplets task")
             triplets = self._extract_triplets(doc.page_content)
@@ -167,7 +197,8 @@ class RAGGraphEngine:
             print(
                 f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
             )
-        return triplets
+            triple_results.extend(triplets)
+        return triple_results
         # for triplet in triplets:
         #     subj, _, obj = triplet
         #     self.graph_store.upsert_triplet(*triplet)
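The commented-out block above preserves the earlier thread-pool approach: split the documents into batches, submit each batch to a ThreadPoolExecutor, and merge the per-batch triplets. A minimal standalone sketch of that pattern (extract_triplets here is a hypothetical stub standing in for the LLM-backed _extract_triplets_task; none of these names are the project's API):

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Tuple

def extract_triplets(text: str) -> List[Tuple[str, str, str]]:
    # Hypothetical stub: the real task prompts an LLM for each document chunk.
    return [(text[:20], "mentions", "example")]

def build_triplets_concurrently(documents: List[str], num_threads: int = 5) -> List[Tuple[str, str, str]]:
    """Split documents into batches and extract triplets from each batch in parallel."""
    if not documents:
        return []
    # Same sizing rule as the commented-out code: small inputs collapse into a single batch.
    chunk_size = len(documents) if len(documents) < num_threads else len(documents) // num_threads
    batches = [documents[i : i + chunk_size] for i in range(0, len(documents), chunk_size)]
    triplets: List[Tuple[str, str, str]] = []
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = [
            executor.submit(lambda batch=batch: [t for d in batch for t in extract_triplets(d)])
            for batch in batches
        ]
        for future in as_completed(futures):
            triplets.extend(future.result())
    return triplets

if __name__ == "__main__":
    print(build_triplets_concurrently(["doc one", "doc two", "doc three"]))

Collecting results with as_completed makes the merge step independent of batch ordering, which is essentially what the commented-out concurrent.futures.wait(...) call was doing.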

View File

@@ -8,7 +8,6 @@ from langchain.schema import Document
 from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
 from pilot.graph_engine.search import BaseSearch, SearchMode
-from pilot.utils import utils
 logger = logging.getLogger(__name__)
 DEFAULT_NODE_SCORE = 1000.0
@@ -113,15 +112,15 @@ class RAGGraphSearch(BaseSearch):
         for keyword in keywords:
             keyword = keyword.lower()
             subjs = set((keyword,))
-            node_ids = self._index_struct.search_node_by_keyword(keyword)
-            for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
-                if node_id in node_visited:
-                    continue
-                # if self._include_text:
-                #     chunk_indices_count[node_id] += 1
-                node_visited.add(node_id)
+            # node_ids = self._index_struct.search_node_by_keyword(keyword)
+            # for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
+            #     if node_id in node_visited:
+            #         continue
+            #
+            #     # if self._include_text:
+            #     #     chunk_indices_count[node_id] += 1
+            #
+            #     node_visited.add(node_id)
             rel_map = self._graph_store.get_rel_map(
                 list(subjs), self.graph_store_query_depth

View File

@@ -89,6 +89,13 @@ class ChatScene(Enum):
         ["Extract Select"],
         True,
     )
+    ExtractSummary = Scene(
+        "extract_summary",
+        "Extract Summary",
+        "Extract Summary",
+        ["Extract Select"],
+        True,
+    )
     ExtractEntity = Scene(
         "extract_entity", "Extract Entity", "Extract Entity", ["Extract Select"], True
     )
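Each ChatScene member wraps a Scene value object, and the service code later in this commit resolves the new member with ChatScene.ExtractSummary.value(). A rough self-contained sketch of that enum-of-value-objects pattern, assuming Scene's field order matches the positional arguments above and that value() resolves to the scene code string (both inferred from this diff, not confirmed against the project's class):

from dataclasses import dataclass, field
from enum import Enum
from typing import List

@dataclass
class Scene:
    # Field names are assumptions matching the positional arguments used in the enum.
    code: str
    name: str
    describe: str
    param_types: List[str] = field(default_factory=list)
    is_inner: bool = False

class ChatScene(Enum):
    ExtractSummary = Scene(
        "extract_summary", "Extract Summary", "Extract Summary", ["Extract Select"], True
    )

    def value(self):
        # Resolve a member to its scene code string, mirroring the call
        # ChatScene.ExtractSummary.value() in the knowledge service below.
        return self._value_.code

if __name__ == "__main__":
    print(ChatScene.ExtractSummary.value())  # -> "extract_summary"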

View File

@@ -15,6 +15,7 @@ class ChatFactory(metaclass=Singleton):
         from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
         from pilot.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
         from pilot.scene.chat_knowledge.extract_entity.chat import ExtractEntity
+        from pilot.scene.chat_knowledge.summary.chat import ExtractSummary
         from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
         from pilot.scene.chat_agent.chat import ChatAgent
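The new import matters for its side effect: if, as these in-function imports suggest, ChatFactory resolves implementations by scanning BaseChat subclasses, then importing ExtractSummary is what makes the "extract_summary" scene discoverable. A small illustrative sketch of that registration-by-import pattern (the base class attribute and function signature here are hypothetical, not the project's exact API):

from typing import Optional, Type

class BaseChat:
    chat_scene: str = ""

class ExtractSummary(BaseChat):
    # Stand-in for pilot.scene.chat_knowledge.summary.chat.ExtractSummary,
    # keyed by the new scene code.
    chat_scene = "extract_summary"

def get_implementation(chat_mode: str) -> BaseChat:
    """Pick the chat implementation whose chat_scene matches, scanning imported subclasses."""
    implementation: Optional[Type[BaseChat]] = None
    for cls in BaseChat.__subclasses__():
        if cls.chat_scene == chat_mode:
            implementation = cls
    if implementation is None:
        raise ValueError(f"Invalid chat mode: {chat_mode}")
    return implementation()

if __name__ == "__main__":
    print(type(get_implementation("extract_summary")).__name__)  # ExtractSummary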

View File

@@ -280,12 +280,6 @@ class KnowledgeService:
             embedding_factory=embedding_factory,
         )
         chunk_docs = client.read()
-        from pilot.graph_engine.graph_factory import RAGGraphFactory
-        rag_engine = CFG.SYSTEM_APP.get_component(
-            ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
-        ).create()
-        rag_engine.knowledge_graph(docs=chunk_docs)
         # update document status
         doc.status = SyncStatus.RUNNING.name
         doc.chunk_size = len(chunk_docs)
@@ -294,8 +288,8 @@ class KnowledgeService:
         executor = CFG.SYSTEM_APP.get_component(
             ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
         ).create()
-        executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
+        executor.submit(self.async_knowledge_graph, chunk_docs, doc)
+        # executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
         logger.info(f"begin save document chunks, doc:{doc.doc_name}")
         # save chunk details
         chunk_entities = [
@@ -397,13 +391,40 @@ class KnowledgeService:
         res.total = document_chunk_dao.get_document_chunks_count(query)
         res.page = request.page
         return res
+    def async_knowledge_graph(self, chunk_docs, doc):
+        """async document extract triplets and save into graph db
+        Args:
+            - chunk_docs: List[Document]
+            - doc: KnowledgeDocumentEntity
+        """
+        for doc in chunk_docs:
+            text = doc.page_content
+            self._llm_extract_summary(text)
+        logger.info(
+            f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
+        )
+        # try:
+        #     from pilot.graph_engine.graph_factory import RAGGraphFactory
+        #
+        #     rag_engine = CFG.SYSTEM_APP.get_component(
+        #         ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+        #     ).create()
+        #     rag_engine.knowledge_graph(chunk_docs)
+        #     doc.status = SyncStatus.FINISHED.name
+        #     doc.result = "document build graph success"
+        # except Exception as e:
+        #     doc.status = SyncStatus.FAILED.name
+        #     doc.result = "document build graph failed" + str(e)
+        #     logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
+        return knowledge_document_dao.update_knowledge_document(doc)
     def async_doc_embedding(self, client, chunk_docs, doc):
         """async document embedding into vector db
         Args:
             - client: EmbeddingEngine Client
             - chunk_docs: List[Document]
-            - doc: doc
+            - doc: KnowledgeDocumentEntity
         """
         logger.info(
             f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
@@ -461,3 +482,24 @@ class KnowledgeService:
         if space.context is not None:
             return json.loads(spaces[0].context)
         return None
+    def _llm_extract_summary(self, doc: str):
+        """Extract summary from text by llm"""
+        from pilot.scene.base import ChatScene
+        from pilot.common.chat_util import llm_chat_response_nostream
+        import uuid
+        chat_param = {
+            "chat_session_id": uuid.uuid1(),
+            "current_user_input": doc,
+            "select_param": "summery",
+            "model_name": "proxyllm",
+        }
+        from pilot.utils import utils
+        loop = utils.get_or_create_event_loop()
+        triplets = loop.run_until_complete(
+            llm_chat_response_nostream(
+                ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
+            )
+        )
+        return triplets
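_llm_extract_summary runs inside a synchronous executor worker, so the async llm_chat_response_nostream call has to be driven through an event loop; the method above does one run_until_complete per chunk. A minimal sketch of the same sync-driving-async pattern, with a hypothetical fake_llm_summary standing in for the real LLM call and asyncio.gather batching the per-chunk requests (a variation for illustration, not the commit's exact flow):

import asyncio
from typing import List

async def fake_llm_summary(chunk: str) -> str:
    # Hypothetical stand-in for llm_chat_response_nostream; the real call goes to the LLM proxy.
    await asyncio.sleep(0)  # simulate I/O-bound LLM latency
    return f"summary({chunk[:20]}...)"

def summarize_chunks(chunks: List[str]) -> List[str]:
    """Drive the async summary calls from synchronous worker code on a single event loop."""
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(
            asyncio.gather(*(fake_llm_summary(c) for c in chunks))
        )
    finally:
        loop.close()

if __name__ == "__main__":
    print(summarize_chunks(["first chunk of a document", "second chunk"]))

Gathering the coroutines keeps all per-chunk requests on one loop instead of spinning the loop once per chunk; whether that is preferable here depends on how the LLM backend handles concurrent requests.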