feat:rag graph

This commit is contained in:
aries_ckt 2023-10-16 21:14:20 +08:00
parent b63fa2dfe1
commit 68c9010e5c
3 changed files with 43 additions and 19 deletions

View File

@ -106,21 +106,50 @@ class RAGGraphEngine:
def _build_index_from_docs(self, documents: List[Document]) -> KG:
"""Build the index from nodes."""
index_struct = self.index_struct_cls()
for doc in documents:
triplets = self._extract_triplets(doc.page_content)
if len(triplets) == 0:
continue
text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
logger.info(f"extracted knowledge triplets: {triplets}")
for triplet in triplets:
subj, _, obj = triplet
self.graph_store.upsert_triplet(*triplet)
index_struct.add_node([subj, obj], text_node)
num_threads = 5
chunk_size = len(documents) if (len(documents) < num_threads) else len(documents) / num_threads
import concurrent
future_tasks = []
with concurrent.futures.ThreadPoolExecutor() as executor:
for i in range(num_threads):
start = i * chunk_size
end = start + chunk_size if i < num_threads - 1 else None
future_tasks.append(executor.submit(self._extract_triplets_task, documents[start:end][0], index_struct))
result = [future.result() for future in future_tasks]
return index_struct
# for doc in documents:
# triplets = self._extract_triplets(doc.page_content)
# if len(triplets) == 0:
# continue
# text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
# logger.info(f"extracted knowledge triplets: {triplets}")
# for triplet in triplets:
# subj, _, obj = triplet
# self.graph_store.upsert_triplet(*triplet)
# index_struct.add_node([subj, obj], text_node)
#
# return index_struct
def search(self, query):
    """Run ``query`` through a ``RAGGraphSearch`` bound to this engine.

    Args:
        query: The query passed unchanged to ``RAGGraphSearch.search``.

    Returns:
        Whatever ``RAGGraphSearch.search`` returns for ``query``.
    """
    # Kept as a function-scope import, as in the original - presumably to
    # avoid a circular import with the graph_search module; verify before
    # moving it to module level.
    from pilot.graph_engine.graph_search import RAGGraphSearch

    return RAGGraphSearch(graph_engine=self).search(query)
def _extract_triplets_task(self, doc, index_struct):
    """Extract the triplets of one document and record them.

    Each extracted triplet is upserted into ``self.graph_store`` once,
    and its subject and object are registered in ``index_struct``
    against a text node built from the document.

    Args:
        doc: Document providing ``page_content`` and ``metadata``.
        index_struct: Index structure that receives the nodes via
            ``add_node``; it is mutated in place.

    Returns:
        None. Results are recorded through side effects only.
    """
    import threading

    thread_id = threading.get_ident()
    # Progress goes through the module logger rather than ``print`` so
    # it follows the application's logging configuration.
    logger.info(f"current thread-{thread_id} begin extract triplets task")
    triplets = self._extract_triplets(doc.page_content)
    logger.info(f"extracted knowledge triplets: {triplets}")
    logger.info(
        f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
    )
    if not triplets:
        # Nothing was extracted: skip building a text node that would
        # never be attached to anything.
        return

    text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
    for triplet in triplets:
        subj, _, obj = triplet
        # Upsert exactly once per triplet; the second, identical call
        # was redundant.
        self.graph_store.upsert_triplet(*triplet)
        index_struct.add_node([subj, obj], text_node)

View File

@ -107,10 +107,9 @@ class BaseChat(ABC):
async def __call_base(self):
import inspect
input_values = (
await self.generate_input_values()
if inspect.isawaitable(self.generate_input_values())
if inspect.isawaitable(self.generate_input_values)
else self.generate_input_values()
)
### Chat sequence advance
@ -181,7 +180,7 @@ class BaseChat(ABC):
span.end(metadata={"error": str(e)})
async def nostream_call(self):
payload = self.__call_base()
payload = await self.__call_base()
logger.info(f"Request: \n{payload}")
ai_response_text = ""
span = root_tracer.start_span(

View File

@ -1,11 +1,7 @@
import os
import json
import logging
import weaviate
#import weaviate
from langchain.schema import Document
from langchain.vectorstores import Weaviate
from weaviate.exceptions import WeaviateBaseError
from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.vector_store.base import VectorStoreBase
@ -72,7 +68,7 @@ class WeaviateStore(VectorStoreBase):
if self.vector_store_client.schema.get(self.vector_name):
return True
return False
except WeaviateBaseError as e:
except Exception as e:
logger.error("vector_name_exists error", e.message)
return False