feat: add GraphRAG framework and integrate TuGraph (#1506)

Co-authored-by: KingSkyLi <15566300566@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Authored by Florian on 2024-05-16 15:39:50 +08:00, committed by GitHub
parent 593e974405
commit a9087c3853
133 changed files with 10139 additions and 6631 deletions


@@ -25,8 +25,9 @@ from dbgpt.app.knowledge.request.response import (
 )
 from dbgpt.component import ComponentType
 from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
-from dbgpt.core import Chunk
+from dbgpt.core import Chunk, LLMClient
 from dbgpt.model import DefaultLLMClient
+from dbgpt.model.cluster import WorkerManagerFactory
 from dbgpt.rag.assembler.embedding import EmbeddingAssembler
 from dbgpt.rag.assembler.summary import SummaryAssembler
 from dbgpt.rag.chunk_manager import ChunkParameters
@@ -39,7 +40,7 @@ from dbgpt.rag.text_splitter.text_splitter import (
 )
 from dbgpt.serve.rag.api.schemas import KnowledgeSyncRequest
 from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
-from dbgpt.serve.rag.service.service import Service, SyncStatus
+from dbgpt.serve.rag.service.service import SyncStatus
 from dbgpt.storage.vector_store.base import VectorStoreConfig
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
 from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
@@ -52,7 +53,6 @@ document_chunk_dao = DocumentChunkDao()
 logger = logging.getLogger(__name__)
 CFG = Config()
 
 # default summary max iteration call with llm.
 DEFAULT_SUMMARY_MAX_ITERATION = 5
 # default summary concurrency call with llm.
@@ -70,6 +70,13 @@ class KnowledgeService:
     def __init__(self):
         pass
 
+    @property
+    def llm_client(self) -> LLMClient:
+        worker_manager = CFG.SYSTEM_APP.get_component(
+            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+        ).create()
+        return DefaultLLMClient(worker_manager, True)
+
     def create_knowledge_space(self, request: KnowledgeSpaceRequest):
         """create knowledge space
         Args:
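The hunk above introduces a lazy `llm_client` property: rather than caching a client in `__init__`, the service resolves the shared worker manager from the system app on every access, so it always sees the currently registered workers. Below is a minimal, self-contained sketch of that pattern; `Registry`, `ClientFactory`, and `Client` are hypothetical stand-ins, not DB-GPT APIs.

# Hypothetical sketch of the lazy-client property pattern; not DB-GPT code.
class Client:
    def __init__(self, backend: str) -> None:
        self.backend = backend

class ClientFactory:
    def __init__(self, backend: str) -> None:
        self._backend = backend

    def create(self) -> Client:
        return Client(self._backend)

class Registry:
    """Stand-in for the CFG.SYSTEM_APP component registry."""

    def __init__(self) -> None:
        self._components: dict[str, object] = {}

    def register(self, name: str, component: object) -> None:
        self._components[name] = component

    def get_component(self, name: str) -> object:
        return self._components[name]

class Service:
    def __init__(self, registry: Registry) -> None:
        self._registry = registry  # no client is cached at construction time

    @property
    def llm_client(self) -> Client:
        # Resolved on each access, mirroring the property added above.
        factory = self._registry.get_component("worker_manager_factory")
        assert isinstance(factory, ClientFactory)
        return factory.create()

registry = Registry()
registry.register("worker_manager_factory", ClientFactory(backend="vllm"))
assert Service(registry).llm_client.backend == "vllm"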
@@ -332,16 +339,23 @@ class KnowledgeService:
         embedding_fn = embedding_factory.create(
             model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
         )
+
+        spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
+        if len(spaces) != 1:
+            raise Exception(f"invalid space name:{space_name}")
+        space = spaces[0]
+
         from dbgpt.storage.vector_store.base import VectorStoreConfig
 
         config = VectorStoreConfig(
-            name=space_name,
+            name=space.name,
             embedding_fn=embedding_fn,
             max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
+            llm_client=self.llm_client,
+            model_name=self.model_name,
         )
         vector_store_connector = VectorStoreConnector(
-            vector_store_type=CFG.VECTOR_STORE_TYPE,
-            vector_store_config=config,
+            vector_store_type=space.vector_type, vector_store_config=config
         )
         knowledge = KnowledgeFactory.create(
             datasource=doc.content,
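The key change in this hunk is that the connector is now built from the per-space `space.vector_type` instead of the process-wide `CFG.VECTOR_STORE_TYPE`; this is what lets a space created as a knowledge graph route to TuGraph while other spaces keep their vector store. A hedged sketch of that dispatch follows, with made-up store classes and type tags rather than the real `VectorStoreConnector` internals.

# Illustrative per-space store dispatch; the store classes and registry here
# are assumptions, not the actual VectorStoreConnector implementation.
from typing import Callable, Dict

class BaseStore:
    def delete_vector_name(self, name: str) -> None:
        raise NotImplementedError

class VectorStore(BaseStore):
    def delete_vector_name(self, name: str) -> None:
        print(f"drop vector index: {name}")

class GraphStore(BaseStore):
    def delete_vector_name(self, name: str) -> None:
        print(f"drop graph: {name}")

_STORES: Dict[str, Callable[[], BaseStore]] = {
    "Chroma": VectorStore,         # assumed tag for a vector backend
    "KnowledgeGraph": GraphStore,  # assumed tag for the TuGraph backend
}

def connector_for(vector_type: str) -> BaseStore:
    # Mirrors switching on space.vector_type instead of a global setting.
    try:
        return _STORES[vector_type]()
    except KeyError:
        raise ValueError(f"unsupported vector store type: {vector_type}")

The deletion hunks below rely on the same per-space lookup, so a graph-backed space drops its graph rather than a nonexistent vector index.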
@@ -442,21 +456,27 @@ class KnowledgeService:
Args:
- space_name: knowledge space name
"""
query = KnowledgeSpaceEntity(name=space_name)
spaces = knowledge_space_dao.get_knowledge_space(query)
if len(spaces) == 0:
raise Exception(f"delete error, no space name:{space_name} in database")
spaces = knowledge_space_dao.get_knowledge_space(
KnowledgeSpaceEntity(name=space_name)
)
if len(spaces) != 1:
raise Exception(f"invalid space name:{space_name}")
space = spaces[0]
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
config = VectorStoreConfig(name=space.name, embedding_fn=embedding_fn)
config = VectorStoreConfig(
name=space.name,
embedding_fn=embedding_fn,
llm_client=self.llm_client,
model_name=None,
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
vector_store_type=space.vector_type, vector_store_config=config
)
# delete vectors
vector_store_connector.delete_vector_name(space.name)
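This commit also tightens the space lookup across these methods: the old check only rejected a missing space (`len(spaces) == 0`), while the new `len(spaces) != 1` additionally rejects ambiguous duplicates. A toy illustration of the behavioral difference:

# Toy illustration of the tightened lookup; `spaces` stands in for DAO results.
def pick_space(spaces: list, space_name: str):
    # Old behavior: only an empty result was an error, so two rows sharing a
    # name would silently resolve to the first one.
    # New behavior: exactly one row must match.
    if len(spaces) != 1:
        raise Exception(f"invalid space name:{space_name}")
    return spaces[0]

assert pick_space(["a"], "demo") == "a"
for bad in ([], ["a", "b"]):
    try:
        pick_space(bad, "demo")
    except Exception:
        pass  # both missing and duplicated names are now rejected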
@@ -480,6 +500,12 @@ class KnowledgeService:
         documents = knowledge_document_dao.get_documents(document_query)
         if len(documents) != 1:
             raise Exception(f"there are no or more than one document called {doc_name}")
+
+        spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
+        if len(spaces) != 1:
+            raise Exception(f"invalid space name:{space_name}")
+        space = spaces[0]
+
         vector_ids = documents[0].vector_ids
         if vector_ids is not None:
             embedding_factory = CFG.SYSTEM_APP.get_component(
@@ -488,10 +514,14 @@ class KnowledgeService:
             embedding_fn = embedding_factory.create(
                 model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
             )
-            config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
+            config = VectorStoreConfig(
+                name=space.name,
+                embedding_fn=embedding_fn,
+                llm_client=self.llm_client,
+                model_name=None,
+            )
             vector_store_connector = VectorStoreConnector(
-                vector_store_type=CFG.VECTOR_STORE_TYPE,
-                vector_store_config=config,
+                vector_store_type=space.vector_type, vector_store_config=config
             )
             # delete vector by ids
             vector_store_connector.delete_by_ids(vector_ids)
@@ -535,7 +565,7 @@ class KnowledgeService:
         """
         logger.info(
-            f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
+            f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}"
         )
         try:
             with root_tracer.start_span(
@@ -645,3 +675,40 @@ class KnowledgeService:
             **{"chat_param": chat_param},
         )
         return chat
+
+    def query_graph(self, space_name, limit):
+        embedding_factory = CFG.SYSTEM_APP.get_component(
+            "embedding_factory", EmbeddingFactory
+        )
+        embedding_fn = embedding_factory.create(
+            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
+        )
+        spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
+        if len(spaces) != 1:
+            raise Exception(f"invalid space name:{space_name}")
+        space = spaces[0]
+        print(CFG.LLM_MODEL)
+
+        config = VectorStoreConfig(
+            name=space.name,
+            embedding_fn=embedding_fn,
+            max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
+            llm_client=self.llm_client,
+            model_name=None,
+        )
+        vector_store_connector = VectorStoreConnector(
+            vector_store_type=space.vector_type, vector_store_config=config
+        )
+        graph = vector_store_connector.client.query_graph(limit=limit)
+        res = {"nodes": [], "edges": []}
+        for node in graph.vertices():
+            res["nodes"].append({"vid": node.vid})
+        for edge in graph.edges():
+            res["edges"].append(
+                {
+                    "src": edge.sid,
+                    "dst": edge.tid,
+                    "label": edge.props[graph.edge_label],
+                }
+            )
+        return res
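As the loop above shows, `query_graph` flattens the graph client's vertices and edges into a plain dict for the API layer. An illustrative call, assuming a configured DB-GPT system app; the space name and limit are made up:

# Illustrative usage; "my_graph_space" is a hypothetical space name, and a
# running system app with a graph-backed space is assumed.
service = KnowledgeService()
result = service.query_graph(space_name="my_graph_space", limit=100)
# Shape implied by the method body:
# {
#     "nodes": [{"vid": "..."}, ...],
#     "edges": [{"src": "...", "dst": "...", "label": "..."}, ...],
# }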