import logging from typing import Any, Optional, Callable, Tuple, List from langchain.schema import Document from langchain.text_splitter import RecursiveCharacterTextSplitter from pilot.embedding_engine import KnowledgeType from pilot.embedding_engine.knowledge_type import get_knowledge_embedding from pilot.graph_engine.index_struct import KG from pilot.graph_engine.node import TextNode from pilot.utils import utils logger = logging.getLogger(__name__) class RAGGraphEngine: """Knowledge RAG Graph Engine. Build a RAG Graph Client can extract triplets and insert into graph store. Args: knowledge_type (Optional[str]): Default: KnowledgeType.DOCUMENT.value extracting triplets. knowledge_source (Optional[str]): model_name (Optional[str]): llm model name graph_store (Optional[GraphStore]): The graph store to use.refrence:llama-index include_embeddings (bool): Whether to include embeddings in the index. Defaults to False. max_object_length (int): The maximum length of the object in a triplet. Defaults to 128. extract_triplet_fn (Optional[Callable]): The function to use for extracting triplets. Defaults to None. """ index_struct_cls = KG def __init__( self, knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value, knowledge_source: Optional[str] = None, text_splitter=None, graph_store=None, index_struct: Optional[KG] = None, model_name: Optional[str] = None, max_triplets_per_chunk: int = 10, include_embeddings: bool = False, max_object_length: int = 128, extract_triplet_fn: Optional[Callable] = None, **kwargs: Any, ) -> None: """Initialize params.""" from llama_index.graph_stores import SimpleGraphStore # need to set parameters before building index in base class. self.knowledge_source = knowledge_source self.knowledge_type = knowledge_type self.model_name = model_name self.text_splitter = text_splitter self.index_struct = index_struct self.include_embeddings = include_embeddings self.graph_store = graph_store or SimpleGraphStore() # self.graph_store = graph_store self.max_triplets_per_chunk = max_triplets_per_chunk self._max_object_length = max_object_length self._extract_triplet_fn = extract_triplet_fn def knowledge_graph(self, docs=None): """knowledge docs into graph store""" if not docs: if self.text_splitter: self.text_splitter = RecursiveCharacterTextSplitter( chunk_size=2000, chunk_overlap=100 ) knowledge_source = get_knowledge_embedding( knowledge_type=self.knowledge_type, knowledge_source=self.knowledge_source, text_splitter=self.text_splitter, ) docs = knowledge_source.read() if self.index_struct is None: self.index_struct = self._build_index_from_docs(docs) def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]: """Extract triplets from text by function or llm""" if self._extract_triplet_fn is not None: return self._extract_triplet_fn(text) else: return self._llm_extract_triplets(text) def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]: """Extract triplets from text by llm""" from pilot.scene.base import ChatScene from pilot.common.chat_util import llm_chat_response_nostream import uuid chat_param = { "chat_session_id": uuid.uuid1(), "current_user_input": text, "select_param": "triplet", "model_name": self.model_name, } loop = utils.get_or_create_event_loop() triplets = loop.run_until_complete( llm_chat_response_nostream( ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param} ) ) return triplets def _build_index_from_docs(self, documents: List[Document]) -> KG: """Build the index from nodes. Args:documents:List[Document] """ index_struct = self.index_struct_cls() triplets = [] for doc in documents: trips = self._extract_triplets_task([doc], index_struct) triplets.extend(trips) print(triplets) text_node = TextNode(text=doc.page_content, metadata=doc.metadata) for triplet in triplets: subj, _, obj = triplet self.graph_store.upsert_triplet(*triplet) index_struct.add_node([subj, obj], text_node) return index_struct # num_threads = 5 # chunk_size = ( # len(documents) # if (len(documents) < num_threads) # else len(documents) // num_threads # ) # # import concurrent # triples = [] # future_tasks = [] # with concurrent.futures.ThreadPoolExecutor() as executor: # for i in range(num_threads): # start = i * chunk_size # end = start + chunk_size if i < num_threads - 1 else None # # doc = documents[start:end] # future_tasks.append( # executor.submit( # self._extract_triplets_task, # documents[start:end], # index_struct, # ) # ) # # for doc in documents[start:end]: # # future_tasks.append( # # executor.submit( # # self._extract_triplets_task, # # doc, # # index_struct, # # ) # # ) # # # result = [future.result() for future in future_tasks] # completed_futures, _ = concurrent.futures.wait(future_tasks, return_when=concurrent.futures.ALL_COMPLETED) # for future in completed_futures: # # 获取已完成的future的结果并添加到results列表中 # result = future.result() # triplets.extend(result) # print(f"total triplets-{triples}") # for triplet in triplets: # subj, _, obj = triplet # self.graph_store.upsert_triplet(*triplet) # # index_struct.add_node([subj, obj], text_node) # return index_struct # for doc in documents: # triplets = self._extract_triplets(doc.page_content) # if len(triplets) == 0: # continue # text_node = TextNode(text=doc.page_content, metadata=doc.metadata) # logger.info(f"extracted knowledge triplets: {triplets}") # for triplet in triplets: # subj, _, obj = triplet # self.graph_store.upsert_triplet(*triplet) # index_struct.add_node([subj, obj], text_node) # # return index_struct def search(self, query): from pilot.graph_engine.graph_search import RAGGraphSearch graph_search = RAGGraphSearch(graph_engine=self) return graph_search.search(query) def _extract_triplets_task(self, docs, index_struct): triple_results = [] for doc in docs: import threading thread_id = threading.get_ident() print(f"current thread-{thread_id} begin extract triplets task") triplets = self._extract_triplets(doc.page_content) if len(triplets) == 0: triplets = [] text_node = TextNode(text=doc.page_content, metadata=doc.metadata) logger.info(f"extracted knowledge triplets: {triplets}") print( f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}" ) triple_results.extend(triplets) return triple_results # for triplet in triplets: # subj, _, obj = triplet # self.graph_store.upsert_triplet(*triplet) # self.graph_store.upsert_triplet(*triplet) # index_struct.add_node([subj, obj], text_node)