mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 16:18:27 +00:00
style:fmt
This commit is contained in:
parent
f6694d95ec
commit
aff0553b7e
@ -107,15 +107,26 @@ class RAGGraphEngine:
|
||||
"""Build the index from nodes."""
|
||||
index_struct = self.index_struct_cls()
|
||||
num_threads = 5
|
||||
chunk_size = len(documents) if (len(documents) < num_threads) else len(documents) // num_threads
|
||||
chunk_size = (
|
||||
len(documents)
|
||||
if (len(documents) < num_threads)
|
||||
else len(documents) // num_threads
|
||||
)
|
||||
|
||||
import concurrent
|
||||
|
||||
future_tasks = []
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
for i in range(num_threads):
|
||||
start = i * chunk_size
|
||||
end = start + chunk_size if i < num_threads - 1 else None
|
||||
future_tasks.append(executor.submit(self._extract_triplets_task, documents[start:end][0], index_struct))
|
||||
future_tasks.append(
|
||||
executor.submit(
|
||||
self._extract_triplets_task,
|
||||
documents[start:end][0],
|
||||
index_struct,
|
||||
)
|
||||
)
|
||||
|
||||
result = [future.result() for future in future_tasks]
|
||||
return index_struct
|
||||
@ -132,7 +143,6 @@ class RAGGraphEngine:
|
||||
#
|
||||
# return index_struct
|
||||
|
||||
|
||||
def search(self, query):
|
||||
from pilot.graph_engine.graph_search import RAGGraphSearch
|
||||
|
||||
@ -141,6 +151,7 @@ class RAGGraphEngine:
|
||||
|
||||
def _extract_triplets_task(self, doc, index_struct):
|
||||
import threading
|
||||
|
||||
thread_id = threading.get_ident()
|
||||
print(f"current thread-{thread_id} begin extract triplets task")
|
||||
triplets = self._extract_triplets(doc.page_content)
|
||||
@ -148,7 +159,9 @@ class RAGGraphEngine:
|
||||
triplets = []
|
||||
text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
|
||||
logger.info(f"extracted knowledge triplets: {triplets}")
|
||||
print(f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}")
|
||||
print(
|
||||
f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
|
||||
)
|
||||
for triplet in triplets:
|
||||
subj, _, obj = triplet
|
||||
self.graph_store.upsert_triplet(*triplet)
|
||||
|
@ -107,6 +107,7 @@ class BaseChat(ABC):
|
||||
|
||||
async def __call_base(self):
|
||||
import inspect
|
||||
|
||||
input_values = (
|
||||
await self.generate_input_values()
|
||||
if inspect.isawaitable(self.generate_input_values())
|
||||
|
@ -258,9 +258,6 @@ class KnowledgeService:
|
||||
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
|
||||
).create()
|
||||
rag_engine.knowledge_graph(docs=chunk_docs)
|
||||
# docs = engine.search(
|
||||
# "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
|
||||
# )
|
||||
# update document status
|
||||
doc.status = SyncStatus.RUNNING.name
|
||||
doc.chunk_size = len(chunk_docs)
|
||||
|
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import logging
|
||||
#import weaviate
|
||||
|
||||
# import weaviate
|
||||
from langchain.schema import Document
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||
|
Loading…
Reference in New Issue
Block a user