# Source: https://github.com/csunny/DB-GPT.git
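"""RAG graph engine for DB-GPT.

Extracts (subject, predicate, object) triplets from knowledge documents and
stores them in a graph store so they can later be queried via ``search``.
"""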
import logging
from typing import Any, Callable, List, Optional, Tuple

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

from pilot.embedding_engine import KnowledgeType
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
from pilot.graph_engine.index_struct import KG
from pilot.graph_engine.node import TextNode
from pilot.utils import utils

logger = logging.getLogger(__name__)


class RAGGraphEngine:
    """Knowledge RAG Graph Engine.

    Builds a RAG graph client that can extract triplets from documents and
    insert them into a graph store.

    Args:
        knowledge_type (Optional[str]): The knowledge source type.
            Defaults to KnowledgeType.DOCUMENT.value.
        knowledge_source (Optional[str]): The path or identifier of the
            knowledge source to load.
        model_name (Optional[str]): LLM model name used for triplet extraction.
        graph_store (Optional[GraphStore]): The graph store to use.
            Reference: llama-index.
        max_triplets_per_chunk (int): Maximum number of triplets to extract
            per chunk. Defaults to 10.
        include_embeddings (bool): Whether to include embeddings in the index.
            Defaults to False.
        max_object_length (int): The maximum length of the object in a triplet.
            Defaults to 128.
        extract_triplet_fn (Optional[Callable]): The function to use for
            extracting triplets. Defaults to None, in which case the LLM
            extraction path is used.
    """

    index_struct_cls = KG

    def __init__(
        self,
        knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value,
        knowledge_source: Optional[str] = None,
        text_splitter=None,
        graph_store=None,
        index_struct: Optional[KG] = None,
        model_name: Optional[str] = None,
        max_triplets_per_chunk: int = 10,
        include_embeddings: bool = False,
        max_object_length: int = 128,
        extract_triplet_fn: Optional[Callable] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        from llama_index.graph_stores import SimpleGraphStore

        # Parameters must be set before building the index in the base class.
        self.knowledge_source = knowledge_source
        self.knowledge_type = knowledge_type
        self.model_name = model_name
        self.text_splitter = text_splitter
        self.index_struct = index_struct
        self.include_embeddings = include_embeddings
        self.graph_store = graph_store or SimpleGraphStore()
        self.max_triplets_per_chunk = max_triplets_per_chunk
        self._max_object_length = max_object_length
        self._extract_triplet_fn = extract_triplet_fn

    def knowledge_graph(self, docs=None):
        """Load knowledge documents and build them into the graph store."""
        if not docs:
            if not self.text_splitter:
                # Fall back to a default splitter when none was provided.
                self.text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=2000, chunk_overlap=100
                )
            knowledge_source = get_knowledge_embedding(
                knowledge_type=self.knowledge_type,
                knowledge_source=self.knowledge_source,
                text_splitter=self.text_splitter,
            )
            docs = knowledge_source.read()
        if self.index_struct is None:
            self.index_struct = self._build_index_from_docs(docs)

    def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
        """Extract triplets from text via the configured function or the LLM."""
        if self._extract_triplet_fn is not None:
            return self._extract_triplet_fn(text)
        else:
            return self._llm_extract_triplets(text)

    def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
        """Extract triplets from text via the LLM."""
        import uuid

        from pilot.common.chat_util import llm_chat_response_nostream
        from pilot.scene.base import ChatScene

        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": text,
            "select_param": "triplet",
            "model_name": self.model_name,
        }
        loop = utils.get_or_create_event_loop()
        triplets = loop.run_until_complete(
            llm_chat_response_nostream(
                ChatScene.ExtractTriplet.value(), chat_param=chat_param
            )
        )
        return triplets

    def _build_index_from_docs(self, documents: List[Document]) -> KG:
        """Build the index from documents.

        Args:
            documents (List[Document]): the documents to extract triplets from.
        """
        index_struct = self.index_struct_cls()
        for doc in documents:
            triplets = self._extract_triplets_task([doc], index_struct)
            if not triplets:
                continue
            # Build the text node from the same document the triplets came
            # from, so graph nodes stay aligned with their source text.
            text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
            for triplet in triplets:
                subj, _, obj = triplet
                self.graph_store.upsert_triplet(*triplet)
                index_struct.add_node([subj, obj], text_node)
        return index_struct

        # NOTE: an earlier ThreadPoolExecutor-based variant (removed here)
        # split the documents across worker threads, waited on the futures
        # with concurrent.futures.wait, and merged the extracted triplets
        # before upserting them. A hedged sketch of that pattern appears at
        # the end of this module.

    def search(self, query):
        """Run a graph search against the built index."""
        from pilot.graph_engine.graph_search import RAGGraphSearch

        graph_search = RAGGraphSearch(graph_engine=self)
        return graph_search.search(query)

    def _extract_triplets_task(self, docs, index_struct):
        """Extract triplets for a batch of documents.

        ``index_struct`` is kept for signature compatibility; nodes are added
        by the caller in ``_build_index_from_docs``.
        """
        import threading

        triple_results = []
        thread_id = threading.get_ident()
        for doc in docs:
            logger.debug(f"thread-{thread_id} begin extract triplets task")
            triplets = self._extract_triplets(doc.page_content)
            logger.info(f"extracted knowledge triplets: {triplets}")
            logger.debug(
                f"thread-{thread_id} end extract triplets task, triplets-{triplets}"
            )
            triple_results.extend(triplets)
        return triple_results
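

# ---------------------------------------------------------------------------
# A hedged sketch of the parallel extraction variant noted in
# _build_index_from_docs: batch the documents across a ThreadPoolExecutor and
# merge the triplets. The function name and batching scheme are illustrative
# assumptions, not part of the engine's API.
# ---------------------------------------------------------------------------
def parallel_extract_triplets(
    engine: "RAGGraphEngine", documents: List[Document], num_threads: int = 5
) -> List[Tuple[str, str, str]]:
    import concurrent.futures

    chunk_size = max(1, len(documents) // num_threads)
    batches = [
        documents[i : i + chunk_size] for i in range(0, len(documents), chunk_size)
    ]
    triplets: List[Tuple[str, str, str]] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        # index_struct is unused by _extract_triplets_task, so None is passed.
        futures = [
            executor.submit(engine._extract_triplets_task, batch, None)
            for batch in batches
        ]
        for future in concurrent.futures.as_completed(futures):
            triplets.extend(future.result())
    return triplets


# ---------------------------------------------------------------------------
# A minimal usage sketch, assuming the pilot runtime is importable. The toy
# regex extractor below sidesteps the LLM; real deployments would leave
# extract_triplet_fn unset and supply model_name instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import re

    def naive_extract_triplets(text: str) -> List[Tuple[str, str, str]]:
        """Toy extractor: turn 'X is Y' sentences into (X, 'is', Y) triplets."""
        return [
            (subj.strip(), "is", obj.strip())
            for subj, obj in re.findall(r"(\w[\w ]*?) is (\w[\w .-]*\w)", text)
        ]

    engine = RAGGraphEngine(extract_triplet_fn=naive_extract_triplets)
    engine.knowledge_graph(
        docs=[Document(page_content="DuckDB is an embedded database.")]
    )
    # SimpleGraphStore (the default) exposes get(subject); engine.search()
    # additionally requires the full DB-GPT model runtime.
    print(engine.graph_store.get("DuckDB"))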