mirror of https://github.com/csunny/DB-GPT.git, synced 2025-08-21 01:34:24 +00:00

commit 724456dc3e (parent 318979a7bf)

    feat: extract summary
@@ -15,10 +15,12 @@ logger = logging.getLogger(__name__)
 class RAGGraphEngine:
     """Knowledge RAG Graph Engine.
-    Build a KG by extracting triplets, and leveraging the KG during query-time.
+    Build a RAG Graph Client can extract triplets and insert into graph store.
     Args:
         knowledge_type (Optional[str]): Default: KnowledgeType.DOCUMENT.value
             extracting triplets.
+        knowledge_source (Optional[str]):
+        model_name (Optional[str]): llm model name
         graph_store (Optional[GraphStore]): The graph store to use.refrence:llama-index
         include_embeddings (bool): Whether to include embeddings in the index.
             Defaults to False.
@@ -104,37 +106,64 @@ class RAGGraphEngine:
         return triplets
 
     def _build_index_from_docs(self, documents: List[Document]) -> KG:
-        """Build the index from nodes."""
+        """Build the index from nodes.
+        Args:documents:List[Document]
+        """
         index_struct = self.index_struct_cls()
-        num_threads = 5
-        chunk_size = (
-            len(documents)
-            if (len(documents) < num_threads)
-            else len(documents) // num_threads
-        )
-
-        import concurrent
-
-        future_tasks = []
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            for i in range(num_threads):
-                start = i * chunk_size
-                end = start + chunk_size if i < num_threads - 1 else None
-                future_tasks.append(
-                    executor.submit(
-                        self._extract_triplets_task,
-                        documents[start:end][0],
-                        index_struct,
-                    )
-                )
-
-        result = [future.result() for future in future_tasks]
+        triplets = []
+        for doc in documents:
+            trips = self._extract_triplets_task([doc], index_struct)
+            triplets.extend(trips)
+        print(triplets)
+        text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
+        for triplet in triplets:
+            subj, _, obj = triplet
+            self.graph_store.upsert_triplet(*triplet)
+            index_struct.add_node([subj, obj], text_node)
+        return index_struct
+        # num_threads = 5
+        # chunk_size = (
+        #     len(documents)
+        #     if (len(documents) < num_threads)
+        #     else len(documents) // num_threads
+        # )
+        #
+        # import concurrent
+        # triples = []
+        # future_tasks = []
+        # with concurrent.futures.ThreadPoolExecutor() as executor:
+        #     for i in range(num_threads):
+        #         start = i * chunk_size
+        #         end = start + chunk_size if i < num_threads - 1 else None
+        #         # doc = documents[start:end]
+        #         future_tasks.append(
+        #             executor.submit(
+        #                 self._extract_triplets_task,
+        #                 documents[start:end],
+        #                 index_struct,
+        #             )
+        #         )
+        #         # for doc in documents[start:end]:
+        #         #     future_tasks.append(
+        #         #         executor.submit(
+        #         #             self._extract_triplets_task,
+        #         #             doc,
+        #         #             index_struct,
+        #         #         )
+        #         #     )
+        #
+        #     # result = [future.result() for future in future_tasks]
+        #     completed_futures, _ = concurrent.futures.wait(future_tasks, return_when=concurrent.futures.ALL_COMPLETED)
+        #     for future in completed_futures:
+        #         # get the result of each completed future and append it to the results list
+        #         result = future.result()
+        #         triplets.extend(result)
+        # print(f"total triplets-{triples}")
         # for triplet in triplets:
         #     subj, _, obj = triplet
         #     self.graph_store.upsert_triplet(*triplet)
-        # self.graph_store.upsert_triplet(*triplet)
-        # index_struct.add_node([subj, obj], text_node)
-        return index_struct
+        #     # index_struct.add_node([subj, obj], text_node)
+        # return index_struct
 
         # for doc in documents:
         #     triplets = self._extract_triplets(doc.page_content)
         #     if len(triplets) == 0:
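The rewritten _build_index_from_docs drops the thread-pool fan-out and processes documents sequentially: extract triplets per document, upsert each triplet into the graph store, and register the subject/object on the index struct. A minimal, self-contained sketch of that flow, using stand-in Document and graph-store classes rather than the DB-GPT ones:

from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class Document:
    """Stand-in for the langchain Document used in the diff."""
    page_content: str
    metadata: dict = field(default_factory=dict)


class SimpleGraphStore:
    """Toy in-memory graph store; upsert_triplet mirrors the call in the diff."""
    def __init__(self) -> None:
        self.triplets: List[Tuple[str, str, str]] = []

    def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
        self.triplets.append((subj, rel, obj))


def extract_triplets(text: str) -> List[Tuple[str, str, str]]:
    # Placeholder for the LLM-backed extractor; here a naive "subject relation object" split.
    parts = text.split()
    return [tuple(parts[:3])] if len(parts) >= 3 else []


def build_graph(documents: List[Document], store: SimpleGraphStore) -> Dict[str, List[str]]:
    """Sequential build: extract per document, upsert, and map each subj/obj to its source text."""
    node_index: Dict[str, List[str]] = {}
    for doc in documents:
        for triplet in extract_triplets(doc.page_content):
            subj, _, obj = triplet
            store.upsert_triplet(*triplet)
            for key in (subj, obj):
                node_index.setdefault(key, []).append(doc.page_content)
    return node_index


if __name__ == "__main__":
    docs = [Document("DB-GPT supports knowledge-graphs")]
    store = SimpleGraphStore()
    print(build_graph(docs, store), store.triplets)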
@@ -154,9 +183,10 @@ class RAGGraphEngine:
         graph_search = RAGGraphSearch(graph_engine=self)
         return graph_search.search(query)
 
-    def _extract_triplets_task(self, doc, index_struct):
-        import threading
-
-        thread_id = threading.get_ident()
-        print(f"current thread-{thread_id} begin extract triplets task")
-        triplets = self._extract_triplets(doc.page_content)
+    def _extract_triplets_task(self, docs, index_struct):
+        triple_results = []
+        for doc in docs:
+            import threading
+
+            thread_id = threading.get_ident()
+            print(f"current thread-{thread_id} begin extract triplets task")
+            triplets = self._extract_triplets(doc.page_content)
@@ -167,7 +197,8 @@ class RAGGraphEngine:
             print(
                 f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
             )
-            return triplets
+            triple_results.extend(triplets)
+        return triple_results
         # for triplet in triplets:
         #     subj, _, obj = triplet
         #     self.graph_store.upsert_triplet(*triplet)
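The commented-out block above keeps the earlier concurrent variant around. For reference, a generic per-document fan-out with concurrent.futures looks roughly like this; a sketch only, where extract_one stands in for the LLM-backed _extract_triplets:

import concurrent.futures
from typing import Callable, List, Sequence, Tuple

Triplet = Tuple[str, str, str]


def extract_all(
    docs: Sequence[str],
    extract_one: Callable[[str], List[Triplet]],
    max_workers: int = 5,
) -> List[Triplet]:
    """Fan out one extraction task per document and merge the results."""
    triplets: List[Triplet] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(extract_one, doc) for doc in docs]
        done, _ = concurrent.futures.wait(
            futures, return_when=concurrent.futures.ALL_COMPLETED
        )
        for future in done:
            triplets.extend(future.result())  # re-raises if the task failed
    return triplets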
@@ -8,7 +8,6 @@ from langchain.schema import Document
 from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
 from pilot.graph_engine.search import BaseSearch, SearchMode
-from pilot.utils import utils
 
 logger = logging.getLogger(__name__)
 DEFAULT_NODE_SCORE = 1000.0
@@ -113,15 +112,15 @@ class RAGGraphSearch(BaseSearch):
         for keyword in keywords:
             keyword = keyword.lower()
             subjs = set((keyword,))
-            node_ids = self._index_struct.search_node_by_keyword(keyword)
-            for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
-                if node_id in node_visited:
-                    continue
-
-                # if self._include_text:
-                #     chunk_indices_count[node_id] += 1
-
-                node_visited.add(node_id)
+            # node_ids = self._index_struct.search_node_by_keyword(keyword)
+            # for node_id in node_ids[:GLOBAL_EXPLORE_NODE_LIMIT]:
+            #     if node_id in node_visited:
+            #         continue
+            #
+            #     # if self._include_text:
+            #     #     chunk_indices_count[node_id] += 1
+            #
+            #     node_visited.add(node_id)
 
             rel_map = self._graph_store.get_rel_map(
                 list(subjs), self.graph_store_query_depth
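With the index-node lookup commented out, keyword search in RAGGraphSearch leans entirely on the graph store's get_rel_map lookup. A toy sketch of that loop with a stub store; the get_rel_map semantics and the depth default here are assumptions, not the DB-GPT implementation:

from typing import Dict, List

GRAPH_STORE_QUERY_DEPTH = 2  # assumed default, mirrors self.graph_store_query_depth


class StubGraphStore:
    """Stand-in with a get_rel_map-style lookup: subject -> flattened relation paths."""
    def __init__(self, rels: Dict[str, List[str]]) -> None:
        self._rels = rels

    def get_rel_map(self, subjs: List[str], depth: int) -> Dict[str, List[str]]:
        # Depth handling is omitted; a real store would expand paths up to `depth`.
        return {s: self._rels.get(s, []) for s in subjs}


def collect_rel_map(keywords: List[str], store: StubGraphStore) -> Dict[str, List[str]]:
    rel_map: Dict[str, List[str]] = {}
    for keyword in keywords:
        keyword = keyword.lower()
        subjs = {keyword}
        rel_map.update(store.get_rel_map(list(subjs), GRAPH_STORE_QUERY_DEPTH))
    return rel_map


print(collect_rel_map(["DB-GPT"], StubGraphStore({"db-gpt": ["db-gpt, supports, rag"]})))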
@@ -89,6 +89,13 @@ class ChatScene(Enum):
         ["Extract Select"],
         True,
     )
+    ExtractSummary = Scene(
+        "extract_summary",
+        "Extract Summary",
+        "Extract Summary",
+        ["Extract Select"],
+        True,
+    )
     ExtractEntity = Scene(
         "extract_entity", "Extract Entity", "Extract Entity", ["Extract Select"], True
     )
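The new ExtractSummary entry mirrors ExtractEntity. A simplified sketch of the enum-of-Scene pattern it relies on; the Scene field names below are assumptions, the real class lives in pilot.scene.base:

from dataclasses import dataclass, field
from enum import Enum
from typing import List


@dataclass
class Scene:
    """Simplified stand-in for the Scene value object (field names are guesses)."""
    code: str
    name: str
    describe: str
    param_types: List[str] = field(default_factory=list)
    is_inner: bool = False


class ChatScene(Enum):
    ExtractSummary = Scene(
        "extract_summary", "Extract Summary", "Extract Summary", ["Extract Select"], True
    )
    ExtractEntity = Scene(
        "extract_entity", "Extract Entity", "Extract Entity", ["Extract Select"], True
    )

    def value(self) -> str:
        # Mirrors the ChatScene.ExtractSummary.value() call seen later in the diff:
        # return the scene code carried by the Scene value object.
        return self._value_.code


print(ChatScene.ExtractSummary.value())  # -> "extract_summary"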
@@ -15,6 +15,7 @@ class ChatFactory(metaclass=Singleton):
         from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
         from pilot.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
         from pilot.scene.chat_knowledge.extract_entity.chat import ExtractEntity
+        from pilot.scene.chat_knowledge.summary.chat import ExtractSummary
         from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
         from pilot.scene.chat_agent.chat import ChatAgent
@@ -280,12 +280,6 @@ class KnowledgeService:
             embedding_factory=embedding_factory,
         )
         chunk_docs = client.read()
-        from pilot.graph_engine.graph_factory import RAGGraphFactory
-
-        rag_engine = CFG.SYSTEM_APP.get_component(
-            ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
-        ).create()
-        rag_engine.knowledge_graph(docs=chunk_docs)
         # update document status
         doc.status = SyncStatus.RUNNING.name
         doc.chunk_size = len(chunk_docs)
@@ -294,8 +288,8 @@ class KnowledgeService:
         executor = CFG.SYSTEM_APP.get_component(
             ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
         ).create()
-        executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
+        executor.submit(self.async_knowledge_graph, chunk_docs, doc)
+        # executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
         logger.info(f"begin save document chunks, doc:{doc.doc_name}")
         # save chunk details
         chunk_entities = [
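Document sync now submits async_knowledge_graph to the default executor instead of running embedding inline. A generic sketch of that submit-and-log pattern, with a plain ThreadPoolExecutor standing in for the ExecutorFactory component:

import logging
from concurrent.futures import Future, ThreadPoolExecutor

logger = logging.getLogger(__name__)


def process_chunks(chunk_docs: list, doc_name: str) -> str:
    # Stand-in for KnowledgeService.async_knowledge_graph(chunk_docs, doc).
    return f"processed {len(chunk_docs)} chunks for {doc_name}"


def submit_background(executor: ThreadPoolExecutor, chunk_docs: list, doc_name: str) -> Future:
    """Fire-and-forget submission; failures surface through a done-callback."""
    future = executor.submit(process_chunks, chunk_docs, doc_name)

    def _log_result(fut: Future) -> None:
        exc = fut.exception()
        if exc is not None:
            logger.error("background task failed: %s", exc)
        else:
            logger.info(fut.result())

    future.add_done_callback(_log_result)
    return future


with ThreadPoolExecutor(max_workers=2) as executor:
    submit_background(executor, ["chunk-1", "chunk-2"], "demo-doc")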
@@ -397,13 +391,40 @@ class KnowledgeService:
         res.total = document_chunk_dao.get_document_chunks_count(query)
         res.page = request.page
         return res
 
+    def async_knowledge_graph(self, chunk_docs, doc):
+        """async document extract triplets and save into graph db
+        Args:
+            - chunk_docs: List[Document]
+            - doc: KnowledgeDocumentEntity
+        """
+        for doc in chunk_docs:
+            text = doc.page_content
+            self._llm_extract_summary(text)
+        logger.info(
+            f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
+        )
+        # try:
+        #     from pilot.graph_engine.graph_factory import RAGGraphFactory
+        #
+        #     rag_engine = CFG.SYSTEM_APP.get_component(
+        #         ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
+        #     ).create()
+        #     rag_engine.knowledge_graph(chunk_docs)
+        #     doc.status = SyncStatus.FINISHED.name
+        #     doc.result = "document build graph success"
+        # except Exception as e:
+        #     doc.status = SyncStatus.FAILED.name
+        #     doc.result = "document build graph failed" + str(e)
+        #     logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
+        return knowledge_document_dao.update_knowledge_document(doc)
+
     def async_doc_embedding(self, client, chunk_docs, doc):
         """async document embedding into vector db
         Args:
             - client: EmbeddingEngine Client
             - chunk_docs: List[Document]
-            - doc: doc
+            - doc: KnowledgeDocumentEntity
         """
         logger.info(
             f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
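async_knowledge_graph walks the chunk documents and calls _llm_extract_summary per chunk, while the commented block keeps the FINISHED/FAILED bookkeeping from the graph-build path. A hedged sketch combining the two; the names below are stand-ins, not the KnowledgeService API:

import logging
from enum import Enum
from typing import Callable, List, Tuple

logger = logging.getLogger(__name__)


class SyncStatus(Enum):
    RUNNING = "RUNNING"
    FINISHED = "FINISHED"
    FAILED = "FAILED"


def summarize_chunks(
    chunks: List[str],
    extract_summary: Callable[[str], str],
) -> Tuple[SyncStatus, List[str]]:
    """Per-chunk summary extraction with the FINISHED/FAILED bookkeeping the commented block sketches."""
    try:
        summaries = [extract_summary(text) for text in chunks]
        return SyncStatus.FINISHED, summaries
    except Exception as exc:  # report the failure instead of raising, as the commented code does
        logger.error("document summary extraction failed: %s", exc)
        return SyncStatus.FAILED, []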
@@ -461,3 +482,24 @@ class KnowledgeService:
         if space.context is not None:
             return json.loads(spaces[0].context)
         return None
+
+    def _llm_extract_summary(self, doc: str):
+        """Extract triplets from text by llm"""
+        from pilot.scene.base import ChatScene
+        from pilot.common.chat_util import llm_chat_response_nostream
+        import uuid
+
+        chat_param = {
+            "chat_session_id": uuid.uuid1(),
+            "current_user_input": doc,
+            "select_param": "summery",
+            "model_name": "proxyllm",
+        }
+        from pilot.utils import utils
+        loop = utils.get_or_create_event_loop()
+        triplets = loop.run_until_complete(
+            llm_chat_response_nostream(
+                ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
+            )
+        )
+        return triplets
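_llm_extract_summary drives an async chat call from synchronous code by fetching an event loop and blocking on run_until_complete. A self-contained sketch of that pattern with a stub coroutine in place of llm_chat_response_nostream; what utils.get_or_create_event_loop actually does is assumed here:

import asyncio
import uuid


async def fake_llm_summary(chat_param: dict) -> str:
    """Stub coroutine standing in for llm_chat_response_nostream(scene, chat_param=...)."""
    await asyncio.sleep(0)  # pretend to await the model
    return f"summary({chat_param['chat_session_id']}): {chat_param['current_user_input'][:40]}"


def extract_summary_sync(text: str, model_name: str = "proxyllm") -> str:
    chat_param = {
        "chat_session_id": str(uuid.uuid1()),
        "current_user_input": text,
        "select_param": "summary",
        "model_name": model_name,
    }
    # Reuse the current loop if one exists, otherwise create one --
    # roughly what a get_or_create_event_loop helper is expected to do.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(fake_llm_summary(chat_param))


print(extract_summary_sync("DB-GPT builds a knowledge graph from document chunks."))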