style:fmt

This commit is contained in:
aries_ckt 2023-10-19 09:40:05 +08:00
parent f6694d95ec
commit aff0553b7e
4 changed files with 20 additions and 8 deletions

View File

@ -107,15 +107,26 @@ class RAGGraphEngine:
"""Build the index from nodes."""
index_struct = self.index_struct_cls()
num_threads = 5
chunk_size = len(documents) if (len(documents) < num_threads) else len(documents) // num_threads
chunk_size = (
len(documents)
if (len(documents) < num_threads)
else len(documents) // num_threads
)
import concurrent
future_tasks = []
with concurrent.futures.ThreadPoolExecutor() as executor:
for i in range(num_threads):
start = i * chunk_size
end = start + chunk_size if i < num_threads - 1 else None
future_tasks.append(executor.submit(self._extract_triplets_task, documents[start:end][0], index_struct))
future_tasks.append(
executor.submit(
self._extract_triplets_task,
documents[start:end][0],
index_struct,
)
)
result = [future.result() for future in future_tasks]
return index_struct
@ -132,7 +143,6 @@ class RAGGraphEngine:
#
# return index_struct
def search(self, query):
from pilot.graph_engine.graph_search import RAGGraphSearch
@ -141,6 +151,7 @@ class RAGGraphEngine:
def _extract_triplets_task(self, doc, index_struct):
import threading
thread_id = threading.get_ident()
print(f"current thread-{thread_id} begin extract triplets task")
triplets = self._extract_triplets(doc.page_content)
@ -148,7 +159,9 @@ class RAGGraphEngine:
triplets = []
text_node = TextNode(text=doc.page_content, metadata=doc.metadata)
logger.info(f"extracted knowledge triplets: {triplets}")
print(f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}")
print(
f"current thread-{thread_id} end extract triplets tasks, triplets-{triplets}"
)
for triplet in triplets:
subj, _, obj = triplet
self.graph_store.upsert_triplet(*triplet)

View File

@ -107,6 +107,7 @@ class BaseChat(ABC):
async def __call_base(self):
import inspect
input_values = (
await self.generate_input_values()
if inspect.isawaitable(self.generate_input_values())

View File

@ -258,9 +258,6 @@ class KnowledgeService:
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
).create()
rag_engine.knowledge_graph(docs=chunk_docs)
# docs = engine.search(
# "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
# )
# update document status
doc.status = SyncStatus.RUNNING.name
doc.chunk_size = len(chunk_docs)

View File

@ -1,6 +1,7 @@
import os
import logging
#import weaviate
# import weaviate
from langchain.schema import Document
from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH