feat: extract summary

aries_ckt 2023-10-25 21:18:37 +08:00
parent 318979a7bf
commit 724456dc3e
5 changed files with 141 additions and 61 deletions

View File

@@ -15,10 +15,12 @@ logger = logging.getLogger(__name__)
class RAGGraphEngine:
"""Knowledge RAG Graph Engine.
Build a KG by extracting triplets, and leveraging the KG during query-time.
Build a RAG graph client that can extract triplets and insert them into the graph store.
Args:
knowledge_type (Optional[str]): The knowledge type to use when extracting triplets. Defaults to KnowledgeType.DOCUMENT.value.
knowledge_source (Optional[str]): The knowledge source.
model_name (Optional[str]): The LLM model name used for extraction.
graph_store (Optional[GraphStore]): The graph store to use. Reference: llama-index.
include_embeddings (bool): Whether to include embeddings in the index.
Defaults to False.
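Under the documented args, constructing the engine might look like the following (a hypothetical sketch: the import path and exact keyword names are assumptions, not shown in this diff):

    # Hypothetical usage; keyword names assumed from the docstring above.
    from pilot.graph_engine.graph_engine import RAGGraphEngine  # path assumed

    engine = RAGGraphEngine(
        knowledge_type="DOCUMENT",   # KnowledgeType.DOCUMENT.value
        model_name="proxyllm",       # LLM used for triplet extraction
        graph_store=None,            # fall back to the default store (assumed)
        include_embeddings=False,
    )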
@@ -104,37 +106,64 @@ class RAGGraphEngine:
return triplets
def _build_index_from_docs(self, documents: List[Document]) -> KG:
"""Build the index from nodes."""
"""Build the index from nodes.
Args:documents:List[Document]
"""
index_struct = self.index_struct_cls()
num_threads = 5
chunk_size = (
len(documents)
if (len(documents) < num_threads)
else len(documents) // num_threads
)
import concurrent
future_tasks = []
with concurrent.futures.ThreadPoolExecutor() as executor:
for i in range(num_threads):
start = i * chunk_size
end = start + chunk_size if i < num_threads - 1 else None
future_tasks.append(
executor.submit(
self._extract_triplets_task,
documents[start:end][0],
index_struct,
)
)
result = [future.result() for future in future_tasks]
triplets = []
for doc in documents:
    trips = self._extract_triplets_task([doc], index_struct)
    triplets.extend(trips)
    text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
    # link each document's triplets to its own text node only
    for triplet in trips:
        subj, _, obj = triplet
        self.graph_store.upsert_triplet(*triplet)
        index_struct.add_node([subj, obj], text_node)
logger.debug(f"extracted triplets: {triplets}")
# num_threads = 5
# chunk_size = (
# len(documents)
# if (len(documents) < num_threads)
# else len(documents) // num_threads
# )
#
# import concurrent
# triplets = []
# future_tasks = []
# with concurrent.futures.ThreadPoolExecutor() as executor:
# for i in range(num_threads):
# start = i * chunk_size
# end = start + chunk_size if i < num_threads - 1 else None
# # doc = documents[start:end]
# future_tasks.append(
# executor.submit(
# self._extract_triplets_task,
# documents[start:end],
# index_struct,
# )
# )
# # for doc in documents[start:end]:
# # future_tasks.append(
# # executor.submit(
# # self._extract_triplets_task,
# # doc,
# # index_struct,
# # )
# # )
#
# # result = [future.result() for future in future_tasks]
# completed_futures, _ = concurrent.futures.wait(future_tasks, return_when=concurrent.futures.ALL_COMPLETED)
# for future in completed_futures:
# # Get each completed future's result and append it to the results list
# result = future.result()
# triplets.extend(result)
# print(f"total triplets-{triples}")
# for triplet in triplets:
# subj, _, obj = triplet
# self.graph_store.upsert_triplet(*triplet)
# index_struct.add_node([subj, obj], text_node)
return index_struct
# # index_struct.add_node([subj, obj], text_node)
# return index_struct
# for doc in documents:
# triplets = self._extract_triplets(doc.page_content)
# if len(triplets) == 0:
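The commented-out variant above fans extraction out over a thread pool, but its chunked slicing was the fragile part (an earlier revision passed only documents[start:end][0], dropping the rest of each chunk). A minimal self-contained sketch of the intended fan-out, with a stub extractor standing in for self._extract_triplets_task, could look like:

    import concurrent.futures
    from typing import List, Tuple

    def extract_triplets_stub(docs: List[str]) -> List[Tuple[str, str, str]]:
        # stand-in for RAGGraphEngine._extract_triplets_task
        return [(doc, "mentions", doc.upper()) for doc in docs]

    def parallel_extract(documents: List[str], num_threads: int = 5):
        chunk_size = max(1, len(documents) // num_threads)
        futures = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
            # submit one contiguous chunk per task so no document is dropped
            for start in range(0, len(documents), chunk_size):
                futures.append(
                    executor.submit(extract_triplets_stub, documents[start : start + chunk_size])
                )
            triplets = []
            for future in concurrent.futures.as_completed(futures):
                triplets.extend(future.result())
        return triplets

    print(parallel_extract(["doc a", "doc b", "doc c"]))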
@@ -154,20 +183,22 @@ class RAGGraphEngine:
graph_search = RAGGraphSearch(graph_engine=self)
return graph_search.search(query)
def _extract_triplets_task(self, doc, index_struct):
import threading
thread_id = threading.get_ident()
print(f"current thread-{thread_id} begin extract triplets task")
triplets = self._extract_triplets(doc.page_content)
if len(triplets) == 0:
triplets = []
text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
logger.info(f"extracted knowledge triplets: {triplets}")
print(
f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
)
return triplets
def _extract_triplets_task(self, docs, index_struct):
    import threading

    thread_id = threading.get_ident()
    triple_results = []
    for doc in docs:
        logger.info(f"thread-{thread_id} begin extract triplets task")
        triplets = self._extract_triplets(doc.page_content)
        logger.info(
            f"thread-{thread_id} end extract triplets task, triplets-{triplets}"
        )
        triple_results.extend(triplets)
    return triple_results
# for triplet in triplets:
# subj, _, obj = triplet
# self.graph_store.upsert_triplet(*triplet)
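Taken together with _build_index_from_docs, each extracted triplet is written to the graph store while its subject and object are indexed back to the source chunk. The bookkeeping reduces to this pattern (stub store and index, illustrative only):

    from collections import defaultdict

    graph_store = []                  # stand-in for GraphStore.upsert_triplet
    keyword_index = defaultdict(set)  # stand-in for index_struct.add_node

    def index_doc(doc_id, triplets):
        for subj, rel, obj in triplets:
            graph_store.append((subj, rel, obj))
            # both endpoints point back to the chunk they came from
            keyword_index[subj].add(doc_id)
            keyword_index[obj].add(doc_id)

    index_doc("chunk-1", [("alice", "knows", "bob")])
    print(graph_store, dict(keyword_index))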

View File

@@ -8,7 +8,6 @@ from langchain.schema import Document
from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
from pilot.graph_engine.search import BaseSearch, SearchMode
from pilot.utils import utils
logger = logging.getLogger(__name__)
DEFAULT_NODE_SCORE = 1000.0
@@ -113,15 +112,15 @@ class RAGGraphSearch(BaseSearch):
for keyword in keywords:
keyword = keyword.lower()
subjs = set((keyword,))
node_ids = self._index_struct.search_node_by_keyword(keyword)
for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
if node_id in node_visited:
continue
# if self._include_text:
# chunk_indices_count[node_id] += 1
node_visited.add(node_id)
# node_ids = self._index_struct.search_node_by_keyword(keyword)
# for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
# if node_id in node_visited:
# continue
#
# # if self._include_text:
# # chunk_indices_count[node_id] += 1
#
# node_visited.add(node_id)
rel_map = self._graph_store.get_rel_map(
list(subjs), self.graph_store_query_depth
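With the keyword-to-node lookup commented out, the lowercased keywords themselves become the subjects handed to get_rel_map. Assuming the llama-index-style return shape this module references (a dict mapping each subject to flattened relation paths; verify against the GraphStore in use), the result can be rendered into context strings like so:

    # Assumed shape: {subject: [[rel, obj, rel, obj, ...], ...]} (llama-index style).
    rel_map = {
        "alice": [["knows", "bob"], ["works_at", "acme", "located_in", "berlin"]],
    }

    knowledge_sequences = []
    for subj, paths in rel_map.items():
        for path in paths:
            # e.g. "alice -> works_at -> acme -> located_in -> berlin"
            knowledge_sequences.append(" -> ".join([subj, *path]))
    print(knowledge_sequences)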

View File

@@ -89,6 +89,13 @@ class ChatScene(Enum):
["Extract Select"],
True,
)
ExtractSummary = Scene(
"extract_summary",
"Extract Summary",
"Extract Summary",
["Extract Select"],
True,
)
ExtractEntity = Scene(
"extract_entity", "Extract Entity", "Extract Entity", ["Extract Select"], True
)
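The new scene is registered as a Scene-valued Enum member. A minimal sketch of that pattern (Scene's field names are assumptions here, but the value() accessor mirrors the ChatScene.ExtractSummary.value() call made later in this commit):

    from dataclasses import dataclass
    from enum import Enum
    from typing import List

    @dataclass
    class Scene:
        code: str
        name: str
        describe: str
        param_types: List[str]
        is_inner: bool = False

    class ChatScene(Enum):
        ExtractSummary = Scene(
            "extract_summary", "Extract Summary", "Extract Summary", ["Extract Select"], True
        )

        def value(self):  # shadows Enum.value so callers get the scene code
            return self._value_.code

    print(ChatScene.ExtractSummary.value())  # -> "extract_summary"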

View File

@@ -15,6 +15,7 @@ class ChatFactory(metaclass=Singleton):
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
from pilot.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
from pilot.scene.chat_knowledge.extract_entity.chat import ExtractEntity
from pilot.scene.chat_knowledge.summary.chat import ExtractSummary
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
from pilot.scene.chat_agent.chat import ChatAgent

View File

@@ -280,12 +280,6 @@ class KnowledgeService:
embedding_factory=embedding_factory,
)
chunk_docs = client.read()
from pilot.graph_engine.graph_factory import RAGGraphFactory
rag_engine = CFG.SYSTEM_APP.get_component(
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
).create()
rag_engine.knowledge_graph(docs=chunk_docs)
# update document status
doc.status = SyncStatus.RUNNING.name
doc.chunk_size = len(chunk_docs)
@@ -294,8 +288,8 @@ class KnowledgeService:
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
executor.submit(self.async_knowledge_graph, chunk_docs, doc)
# executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
# save chunk details
chunk_entities = [
@@ -397,13 +391,40 @@ class KnowledgeService:
res.total = document_chunk_dao.get_document_chunks_count(query)
res.page = request.page
return res
def async_knowledge_graph(self, chunk_docs, doc):
    """async extract summary from document chunks and save into graph db
    Args:
        - chunk_docs: List[Document]
        - doc: KnowledgeDocumentEntity
    """
    logger.info(
        f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin extracting summaries"
    )
    # the loop variable must not shadow the `doc` entity updated below
    for chunk_doc in chunk_docs:
        text = chunk_doc.page_content
        self._llm_extract_summary(text)
# try:
# from pilot.graph_engine.graph_factory import RAGGraphFactory
#
# rag_engine = CFG.SYSTEM_APP.get_component(
# ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
# ).create()
# rag_engine.knowledge_graph(chunk_docs)
# doc.status = SyncStatus.FINISHED.name
# doc.result = "document build graph success"
# except Exception as e:
# doc.status = SyncStatus.FAILED.name
# doc.result = "document build graph failed" + str(e)
# logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
return knowledge_document_dao.update_knowledge_document(doc)
def async_doc_embedding(self, client, chunk_docs, doc):
"""async document embedding into vector db
Args:
- client: EmbeddingEngine Client
- chunk_docs: List[Document]
- doc: doc
- doc: KnowledgeDocumentEntity
"""
logger.info(
f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
@@ -461,3 +482,24 @@ class KnowledgeService:
if space.context is not None:
return json.loads(spaces[0].context)
return None
def _llm_extract_summary(self, doc: str):
    """Extract a summary from text by llm"""
    import uuid

    from pilot.common.chat_util import llm_chat_response_nostream
    from pilot.scene.base import ChatScene

    chat_param = {
        "chat_session_id": uuid.uuid1(),
        "current_user_input": doc,
        "select_param": "summary",
        "model_name": "proxyllm",
    }
    from pilot.utils import utils

    loop = utils.get_or_create_event_loop()
    summary = loop.run_until_complete(
        llm_chat_response_nostream(
            ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
        )
    )
    return summary
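_llm_extract_summary drives the async chat call from synchronous service code through utils.get_or_create_event_loop. A common implementation of that helper looks roughly like this (an assumption; pilot's actual utility may differ):

    import asyncio

    def get_or_create_event_loop() -> asyncio.AbstractEventLoop:
        # Reuse the current thread's loop when possible; executor threads
        # have no default loop, so install a fresh one on demand.
        try:
            return asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop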