diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py index 80cbab066..a50ebabdc 100644 --- a/pilot/graph_engine/graph_engine.py +++ b/pilot/graph_engine/graph_engine.py @@ -107,15 +107,26 @@ class RAGGraphEngine: """Build the index from nodes.""" index_struct = self.index_struct_cls() num_threads = 5 - chunk_size = len(documents) if (len(documents) < num_threads) else len(documents) // num_threads + chunk_size = ( + len(documents) + if (len(documents) < num_threads) + else len(documents) // num_threads + ) import concurrent + future_tasks = [] with concurrent.futures.ThreadPoolExecutor() as executor: for i in range(num_threads): start = i * chunk_size end = start + chunk_size if i < num_threads - 1 else None - future_tasks.append(executor.submit(self._extract_triplets_task, documents[start:end][0], index_struct)) + future_tasks.append( + executor.submit( + self._extract_triplets_task, + documents[start:end][0], + index_struct, + ) + ) result = [future.result() for future in future_tasks] return index_struct @@ -132,7 +143,6 @@ class RAGGraphEngine: # # return index_struct - def search(self, query): from pilot.graph_engine.graph_search import RAGGraphSearch @@ -141,6 +151,7 @@ class RAGGraphEngine: def _extract_triplets_task(self, doc, index_struct): import threading + thread_id = threading.get_ident() print(f"current thread-{thread_id} begin extract triplets task") triplets = self._extract_triplets(doc.page_content) @@ -148,7 +159,9 @@ class RAGGraphEngine: triplets = [] text_node = TextNode(text=doc.page_content, metadata=doc.metadata) logger.info(f"extracted knowledge triplets: {triplets}") - print(f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}") + print( + f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}" + ) for triplet in triplets: subj, _, obj = triplet self.graph_store.upsert_triplet(*triplet) diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index a1d6d9f08..10c89d620 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -107,6 +107,7 @@ class BaseChat(ABC): async def __call_base(self): import inspect + input_values = ( await self.generate_input_values() if inspect.isawaitable(self.generate_input_values()) diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py index 95949d319..7bba99c0a 100644 --- a/pilot/server/knowledge/service.py +++ b/pilot/server/knowledge/service.py @@ -258,9 +258,6 @@ class KnowledgeService: ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory ).create() rag_engine.knowledge_graph(docs=chunk_docs) - # docs = engine.search( - # "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA" - # ) # update document status doc.status = SyncStatus.RUNNING.name doc.chunk_size = len(chunk_docs) diff --git a/pilot/vector_store/weaviate_store.py b/pilot/vector_store/weaviate_store.py index a8e126eb5..93816ea66 100644 --- a/pilot/vector_store/weaviate_store.py +++ b/pilot/vector_store/weaviate_store.py @@ -1,6 +1,7 @@ import os import logging -#import weaviate + +# import weaviate from langchain.schema import Document from pilot.configs.config import Config from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH