mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-05 10:29:36 +00:00
feat:rag graph
This commit is contained in:
parent
b63fa2dfe1
commit
68c9010e5c
@ -106,21 +106,50 @@ class RAGGraphEngine:
|
|||||||
def _build_index_from_docs(self, documents: List[Document]) -> KG:
|
def _build_index_from_docs(self, documents: List[Document]) -> KG:
|
||||||
"""Build the index from nodes."""
|
"""Build the index from nodes."""
|
||||||
index_struct = self.index_struct_cls()
|
index_struct = self.index_struct_cls()
|
||||||
for doc in documents:
|
num_threads = 5
|
||||||
triplets = self._extract_triplets(doc.page_content)
|
chunk_size = len(documents) if (len(documents) < num_threads) else len(documents) / num_threads
|
||||||
if len(triplets) == 0:
|
|
||||||
continue
|
|
||||||
text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
|
|
||||||
logger.info(f"extracted knowledge triplets: {triplets}")
|
|
||||||
for triplet in triplets:
|
|
||||||
subj, _, obj = triplet
|
|
||||||
self.graph_store.upsert_triplet(*triplet)
|
|
||||||
index_struct.add_node([subj, obj], text_node)
|
|
||||||
|
|
||||||
|
import concurrent
|
||||||
|
future_tasks = []
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
for i in range(num_threads):
|
||||||
|
start = i * chunk_size
|
||||||
|
end = start + chunk_size if i < num_threads - 1 else None
|
||||||
|
future_tasks.append(executor.submit(self._extract_triplets_task, documents[start:end][0], index_struct))
|
||||||
|
|
||||||
|
result = [future.result() for future in future_tasks]
|
||||||
return index_struct
|
return index_struct
|
||||||
|
# for doc in documents:
|
||||||
|
# triplets = self._extract_triplets(doc.page_content)
|
||||||
|
# if len(triplets) == 0:
|
||||||
|
# continue
|
||||||
|
# text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
|
||||||
|
# logger.info(f"extracted knowledge triplets: {triplets}")
|
||||||
|
# for triplet in triplets:
|
||||||
|
# subj, _, obj = triplet
|
||||||
|
# self.graph_store.upsert_triplet(*triplet)
|
||||||
|
# index_struct.add_node([subj, obj], text_node)
|
||||||
|
#
|
||||||
|
# return index_struct
|
||||||
|
|
||||||
def search(self, query):
|
def search(self, query):
|
||||||
from pilot.graph_engine.graph_search import RAGGraphSearch
|
from pilot.graph_engine.graph_search import RAGGraphSearch
|
||||||
|
|
||||||
graph_search = RAGGraphSearch(graph_engine=self)
|
graph_search = RAGGraphSearch(graph_engine=self)
|
||||||
return graph_search.search(query)
|
return graph_search.search(query)
|
||||||
|
|
||||||
|
def _extract_triplets_task(self, doc, index_struct):
|
||||||
|
import threading
|
||||||
|
thread_id = threading.get_ident()
|
||||||
|
print(f"current thread-{thread_id} begin extract triplets task")
|
||||||
|
triplets = self._extract_triplets(doc.page_content)
|
||||||
|
if len(triplets) == 0:
|
||||||
|
triplets = []
|
||||||
|
text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
|
||||||
|
logger.info(f"extracted knowledge triplets: {triplets}")
|
||||||
|
print(f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}")
|
||||||
|
for triplet in triplets:
|
||||||
|
subj, _, obj = triplet
|
||||||
|
self.graph_store.upsert_triplet(*triplet)
|
||||||
|
self.graph_store.upsert_triplet(*triplet)
|
||||||
|
index_struct.add_node([subj, obj], text_node)
|
||||||
|
@ -107,10 +107,9 @@ class BaseChat(ABC):
|
|||||||
|
|
||||||
async def __call_base(self):
|
async def __call_base(self):
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
input_values = (
|
input_values = (
|
||||||
await self.generate_input_values()
|
await self.generate_input_values()
|
||||||
if inspect.isawaitable(self.generate_input_values())
|
if inspect.isawaitable(self.generate_input_values)
|
||||||
else self.generate_input_values()
|
else self.generate_input_values()
|
||||||
)
|
)
|
||||||
### Chat sequence advance
|
### Chat sequence advance
|
||||||
@ -181,7 +180,7 @@ class BaseChat(ABC):
|
|||||||
span.end(metadata={"error": str(e)})
|
span.end(metadata={"error": str(e)})
|
||||||
|
|
||||||
async def nostream_call(self):
|
async def nostream_call(self):
|
||||||
payload = self.__call_base()
|
payload = await self.__call_base()
|
||||||
logger.info(f"Request: \n{payload}")
|
logger.info(f"Request: \n{payload}")
|
||||||
ai_response_text = ""
|
ai_response_text = ""
|
||||||
span = root_tracer.start_span(
|
span = root_tracer.start_span(
|
||||||
|
@ -1,11 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import weaviate
|
#import weaviate
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
from langchain.vectorstores import Weaviate
|
|
||||||
from weaviate.exceptions import WeaviateBaseError
|
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
|
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||||
from pilot.vector_store.base import VectorStoreBase
|
from pilot.vector_store.base import VectorStoreBase
|
||||||
@ -72,7 +68,7 @@ class WeaviateStore(VectorStoreBase):
|
|||||||
if self.vector_store_client.schema.get(self.vector_name):
|
if self.vector_store_client.schema.get(self.vector_name):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
except WeaviateBaseError as e:
|
except Exception as e:
|
||||||
logger.error("vector_name_exists error", e.message)
|
logger.error("vector_name_exists error", e.message)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user