mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-20 01:07:15 +00:00
feat:extract summary
This commit is contained in:
parent
318979a7bf
commit
724456dc3e
@ -15,10 +15,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class RAGGraphEngine:
|
||||
"""Knowledge RAG Graph Engine.
|
||||
Build a KG by extracting triplets, and leveraging the KG during query-time.
|
||||
Build a RAG Graph Client can extract triplets and insert into graph store.
|
||||
Args:
|
||||
knowledge_type (Optional[str]): Default: KnowledgeType.DOCUMENT.value
|
||||
extracting triplets.
|
||||
knowledge_source (Optional[str]):
|
||||
model_name (Optional[str]): llm model name
|
||||
graph_store (Optional[GraphStore]): The graph store to use.refrence:llama-index
|
||||
include_embeddings (bool): Whether to include embeddings in the index.
|
||||
Defaults to False.
|
||||
@ -104,37 +106,64 @@ class RAGGraphEngine:
|
||||
return triplets
|
||||
|
||||
def _build_index_from_docs(self, documents: List[Document]) -> KG:
|
||||
"""Build the index from nodes."""
|
||||
"""Build the index from nodes.
|
||||
Args:documents:List[Document]
|
||||
"""
|
||||
index_struct = self.index_struct_cls()
|
||||
num_threads = 5
|
||||
chunk_size = (
|
||||
len(documents)
|
||||
if (len(documents) < num_threads)
|
||||
else len(documents) // num_threads
|
||||
)
|
||||
|
||||
import concurrent
|
||||
|
||||
future_tasks = []
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
for i in range(num_threads):
|
||||
start = i * chunk_size
|
||||
end = start + chunk_size if i < num_threads - 1 else None
|
||||
future_tasks.append(
|
||||
executor.submit(
|
||||
self._extract_triplets_task,
|
||||
documents[start:end][0],
|
||||
index_struct,
|
||||
)
|
||||
)
|
||||
|
||||
result = [future.result() for future in future_tasks]
|
||||
triplets = []
|
||||
for doc in documents:
|
||||
trips = self._extract_triplets_task([doc], index_struct)
|
||||
triplets.extend(trips)
|
||||
print(triplets)
|
||||
text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
|
||||
for triplet in triplets:
|
||||
subj, _, obj = triplet
|
||||
self.graph_store.upsert_triplet(*triplet)
|
||||
index_struct.add_node([subj, obj], text_node)
|
||||
return index_struct
|
||||
# num_threads = 5
|
||||
# chunk_size = (
|
||||
# len(documents)
|
||||
# if (len(documents) < num_threads)
|
||||
# else len(documents) // num_threads
|
||||
# )
|
||||
#
|
||||
# import concurrent
|
||||
# triples = []
|
||||
# future_tasks = []
|
||||
# with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
# for i in range(num_threads):
|
||||
# start = i * chunk_size
|
||||
# end = start + chunk_size if i < num_threads - 1 else None
|
||||
# # doc = documents[start:end]
|
||||
# future_tasks.append(
|
||||
# executor.submit(
|
||||
# self._extract_triplets_task,
|
||||
# documents[start:end],
|
||||
# index_struct,
|
||||
# )
|
||||
# )
|
||||
# # for doc in documents[start:end]:
|
||||
# # future_tasks.append(
|
||||
# # executor.submit(
|
||||
# # self._extract_triplets_task,
|
||||
# # doc,
|
||||
# # index_struct,
|
||||
# # )
|
||||
# # )
|
||||
#
|
||||
# # result = [future.result() for future in future_tasks]
|
||||
# completed_futures, _ = concurrent.futures.wait(future_tasks, return_when=concurrent.futures.ALL_COMPLETED)
|
||||
# for future in completed_futures:
|
||||
# # 获取已完成的future的结果并添加到results列表中
|
||||
# result = future.result()
|
||||
# triplets.extend(result)
|
||||
# print(f"total triplets-{triples}")
|
||||
# for triplet in triplets:
|
||||
# subj, _, obj = triplet
|
||||
# self.graph_store.upsert_triplet(*triplet)
|
||||
# self.graph_store.upsert_triplet(*triplet)
|
||||
# index_struct.add_node([subj, obj], text_node)
|
||||
return index_struct
|
||||
# # index_struct.add_node([subj, obj], text_node)
|
||||
# return index_struct
|
||||
# for doc in documents:
|
||||
# triplets = self._extract_triplets(doc.page_content)
|
||||
# if len(triplets) == 0:
|
||||
@ -154,20 +183,22 @@ class RAGGraphEngine:
|
||||
graph_search = RAGGraphSearch(graph_engine=self)
|
||||
return graph_search.search(query)
|
||||
|
||||
def _extract_triplets_task(self, doc, index_struct):
|
||||
import threading
|
||||
|
||||
thread_id = threading.get_ident()
|
||||
print(f"current thread-{thread_id} begin extract triplets task")
|
||||
triplets = self._extract_triplets(doc.page_content)
|
||||
if len(triplets) == 0:
|
||||
triplets = []
|
||||
text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
|
||||
logger.info(f"extracted knowledge triplets: {triplets}")
|
||||
print(
|
||||
f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
|
||||
)
|
||||
return triplets
|
||||
def _extract_triplets_task(self, docs, index_struct):
|
||||
triple_results = []
|
||||
for doc in docs:
|
||||
import threading
|
||||
thread_id = threading.get_ident()
|
||||
print(f"current thread-{thread_id} begin extract triplets task")
|
||||
triplets = self._extract_triplets(doc.page_content)
|
||||
if len(triplets) == 0:
|
||||
triplets = []
|
||||
text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
|
||||
logger.info(f"extracted knowledge triplets: {triplets}")
|
||||
print(
|
||||
f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
|
||||
)
|
||||
triple_results.extend(triplets)
|
||||
return triple_results
|
||||
# for triplet in triplets:
|
||||
# subj, _, obj = triplet
|
||||
# self.graph_store.upsert_triplet(*triplet)
|
||||
|
@ -8,7 +8,6 @@ from langchain.schema import Document
|
||||
|
||||
from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
|
||||
from pilot.graph_engine.search import BaseSearch, SearchMode
|
||||
from pilot.utils import utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_NODE_SCORE = 1000.0
|
||||
@ -113,15 +112,15 @@ class RAGGraphSearch(BaseSearch):
|
||||
for keyword in keywords:
|
||||
keyword = keyword.lower()
|
||||
subjs = set((keyword,))
|
||||
node_ids = self._index_struct.search_node_by_keyword(keyword)
|
||||
for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
|
||||
if node_id in node_visited:
|
||||
continue
|
||||
|
||||
# if self._include_text:
|
||||
# chunk_indices_count[node_id] += 1
|
||||
|
||||
node_visited.add(node_id)
|
||||
# node_ids = self._index_struct.search_node_by_keyword(keyword)
|
||||
# for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
|
||||
# if node_id in node_visited:
|
||||
# continue
|
||||
#
|
||||
# # if self._include_text:
|
||||
# # chunk_indices_count[node_id] += 1
|
||||
#
|
||||
# node_visited.add(node_id)
|
||||
|
||||
rel_map = self._graph_store.get_rel_map(
|
||||
list(subjs), self.graph_store_query_depth
|
||||
|
@ -89,6 +89,13 @@ class ChatScene(Enum):
|
||||
["Extract Select"],
|
||||
True,
|
||||
)
|
||||
ExtractSummary = Scene(
|
||||
"extract_summary",
|
||||
"Extract Summary",
|
||||
"Extract Summary",
|
||||
["Extract Select"],
|
||||
True,
|
||||
)
|
||||
ExtractEntity = Scene(
|
||||
"extract_entity", "Extract Entity", "Extract Entity", ["Extract Select"], True
|
||||
)
|
||||
|
@ -15,6 +15,7 @@ class ChatFactory(metaclass=Singleton):
|
||||
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
|
||||
from pilot.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
|
||||
from pilot.scene.chat_knowledge.extract_entity.chat import ExtractEntity
|
||||
from pilot.scene.chat_knowledge.summary.chat import ExtractSummary
|
||||
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
|
||||
from pilot.scene.chat_agent.chat import ChatAgent
|
||||
|
||||
|
@ -280,12 +280,6 @@ class KnowledgeService:
|
||||
embedding_factory=embedding_factory,
|
||||
)
|
||||
chunk_docs = client.read()
|
||||
from pilot.graph_engine.graph_factory import RAGGraphFactory
|
||||
|
||||
rag_engine = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
|
||||
).create()
|
||||
rag_engine.knowledge_graph(docs=chunk_docs)
|
||||
# update document status
|
||||
doc.status = SyncStatus.RUNNING.name
|
||||
doc.chunk_size = len(chunk_docs)
|
||||
@ -294,8 +288,8 @@ class KnowledgeService:
|
||||
executor = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
|
||||
|
||||
executor.submit(self.async_knowledge_graph, chunk_docs, doc)
|
||||
# executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
|
||||
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
|
||||
# save chunk details
|
||||
chunk_entities = [
|
||||
@ -397,13 +391,40 @@ class KnowledgeService:
|
||||
res.total = document_chunk_dao.get_document_chunks_count(query)
|
||||
res.page = request.page
|
||||
return res
|
||||
def async_knowledge_graph(self, chunk_docs, doc):
|
||||
"""async document extract triplets and save into graph db
|
||||
Args:
|
||||
- chunk_docs: List[Document]
|
||||
- doc: KnowledgeDocumentEntity
|
||||
"""
|
||||
for doc in chunk_docs:
|
||||
text = doc.page_content
|
||||
self._llm_extract_summary(text)
|
||||
logger.info(
|
||||
f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
|
||||
)
|
||||
# try:
|
||||
# from pilot.graph_engine.graph_factory import RAGGraphFactory
|
||||
#
|
||||
# rag_engine = CFG.SYSTEM_APP.get_component(
|
||||
# ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
|
||||
# ).create()
|
||||
# rag_engine.knowledge_graph(chunk_docs)
|
||||
# doc.status = SyncStatus.FINISHED.name
|
||||
# doc.result = "document build graph success"
|
||||
# except Exception as e:
|
||||
# doc.status = SyncStatus.FAILED.name
|
||||
# doc.result = "document build graph failed" + str(e)
|
||||
# logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
|
||||
return knowledge_document_dao.update_knowledge_document(doc)
|
||||
|
||||
|
||||
def async_doc_embedding(self, client, chunk_docs, doc):
|
||||
"""async document embedding into vector db
|
||||
Args:
|
||||
- client: EmbeddingEngine Client
|
||||
- chunk_docs: List[Document]
|
||||
- doc: doc
|
||||
- doc: KnowledgeDocumentEntity
|
||||
"""
|
||||
logger.info(
|
||||
f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
|
||||
@ -461,3 +482,24 @@ class KnowledgeService:
|
||||
if space.context is not None:
|
||||
return json.loads(spaces[0].context)
|
||||
return None
|
||||
|
||||
def _llm_extract_summary(self, doc: str):
|
||||
"""Extract triplets from text by llm"""
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.chat_util import llm_chat_response_nostream
|
||||
import uuid
|
||||
|
||||
chat_param = {
|
||||
"chat_session_id": uuid.uuid1(),
|
||||
"current_user_input": doc,
|
||||
"select_param": "summery",
|
||||
"model_name": "proxyllm",
|
||||
}
|
||||
from pilot.utils import utils
|
||||
loop = utils.get_or_create_event_loop()
|
||||
triplets = loop.run_until_complete(
|
||||
llm_chat_response_nostream(
|
||||
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
|
||||
)
|
||||
)
|
||||
return triplets
|
||||
|
Loading…
Reference in New Issue
Block a user