diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py index 04c4f54d9..e34baba79 100644 --- a/pilot/graph_engine/graph_engine.py +++ b/pilot/graph_engine/graph_engine.py @@ -106,21 +106,50 @@ class RAGGraphEngine: def _build_index_from_docs(self, documents: List[Document]) -> KG: """Build the index from nodes.""" index_struct = self.index_struct_cls() - for doc in documents: - triplets = self._extract_triplets(doc.page_content) - if len(triplets) == 0: - continue - text_node = TextNode(text=doc.page_content, metadata=doc.metadata) - logger.info(f"extracted knowledge triplets: {triplets}") - for triplet in triplets: - subj, _, obj = triplet - self.graph_store.upsert_triplet(*triplet) - index_struct.add_node([subj, obj], text_node) + num_threads = 5 + chunk_size = len(documents) if (len(documents) < num_threads) else len(documents) / num_threads + import concurrent + future_tasks = [] + with concurrent.futures.ThreadPoolExecutor() as executor: + for i in range(num_threads): + start = i * chunk_size + end = start + chunk_size if i < num_threads - 1 else None + future_tasks.append(executor.submit(self._extract_triplets_task, documents[start:end][0], index_struct)) + + result = [future.result() for future in future_tasks] return index_struct + # for doc in documents: + # triplets = self._extract_triplets(doc.page_content) + # if len(triplets) == 0: + # continue + # text_node = TextNode(text=doc.page_content, metadata=doc.metadata) + # logger.info(f"extracted knowledge triplets: {triplets}") + # for triplet in triplets: + # subj, _, obj = triplet + # self.graph_store.upsert_triplet(*triplet) + # index_struct.add_node([subj, obj], text_node) + # + # return index_struct def search(self, query): from pilot.graph_engine.graph_search import RAGGraphSearch graph_search = RAGGraphSearch(graph_engine=self) return graph_search.search(query) + + def _extract_triplets_task(self, doc, index_struct): + import threading + thread_id = threading.get_ident() + print(f"current thread-{thread_id} begin extract triplets task") + triplets = self._extract_triplets(doc.page_content) + if len(triplets) == 0: + triplets = [] + text_node = TextNode(text=doc.page_content, metadata=doc.metadata) + logger.info(f"extracted knowledge triplets: {triplets}") + print(f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}") + for triplet in triplets: + subj, _, obj = triplet + self.graph_store.upsert_triplet(*triplet) + self.graph_store.upsert_triplet(*triplet) + index_struct.add_node([subj, obj], text_node) diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index bea00bde3..58a0becc9 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -107,10 +107,9 @@ class BaseChat(ABC): async def __call_base(self): import inspect - input_values = ( await self.generate_input_values() - if inspect.isawaitable(self.generate_input_values()) + if inspect.isawaitable(self.generate_input_values) else self.generate_input_values() ) ### Chat sequence advance @@ -181,7 +180,7 @@ class BaseChat(ABC): span.end(metadata={"error": str(e)}) async def nostream_call(self): - payload = self.__call_base() + payload = await self.__call_base() logger.info(f"Request: \n{payload}") ai_response_text = "" span = root_tracer.start_span( diff --git a/pilot/vector_store/weaviate_store.py b/pilot/vector_store/weaviate_store.py index 795cf21f9..a8e126eb5 100644 --- a/pilot/vector_store/weaviate_store.py +++ b/pilot/vector_store/weaviate_store.py @@ -1,11 +1,7 @@ import os -import json import logging -import weaviate +#import weaviate from langchain.schema import Document -from langchain.vectorstores import Weaviate -from weaviate.exceptions import WeaviateBaseError - from pilot.configs.config import Config from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.vector_store.base import VectorStoreBase @@ -72,7 +68,7 @@ class WeaviateStore(VectorStoreBase): if self.vector_store_client.schema.get(self.vector_name): return True return False - except WeaviateBaseError as e: + except Exception as e: logger.error("vector_name_exists error", e.message) return False