feat: add GraphRAG framework and integrate TuGraph (#1506)

Co-authored-by: KingSkyLi <15566300566@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Florian 2024-05-16 15:39:50 +08:00, committed by GitHub
parent 593e974405
commit a9087c3853
133 changed files with 10139 additions and 6631 deletions

View File

@ -70,6 +70,7 @@ EMBEDDING_MODEL=text2vec
#EMBEDDING_MODEL=bge-large-zh
KNOWLEDGE_CHUNK_SIZE=500
KNOWLEDGE_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_SEARCH_TOP_SIZE=50
## Maximum number of chunks to load at once. If your single document is large,
## you can raise this value for better performance;
## if you run out of memory when loading a large document, lower it.
@ -138,10 +139,12 @@ LOCAL_DB_TYPE=sqlite
EXECUTE_LOCAL_COMMANDS=False
#*******************************************************************#
#** VECTOR STORE SETTINGS **#
#** VECTOR STORE / KNOWLEDGE GRAPH SETTINGS **#
#*******************************************************************#
### Chroma vector db config
VECTOR_STORE_TYPE=Chroma
GRAPH_STORE_TYPE=TuGraph
### Chroma vector db config
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data
### Milvus vector db config
@ -163,6 +166,15 @@ ElasticSearch_PORT=9200
ElasticSearch_USERNAME=elastic
ElasticSearch_PASSWORD=i=+iLw9y0Jduq86XTi6W
### TuGraph config
#TUGRAPH_HOST=127.0.0.1
#TUGRAPH_PORT=7687
#TUGRAPH_USERNAME=admin
#TUGRAPH_PASSWORD=73@TuGraph
#TUGRAPH_VERTEX_TYPE=entity
#TUGRAPH_EDGE_TYPE=relation
#TUGRAPH_EDGE_NAME_KEY=label
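Tying the settings above together, a minimal sketch of how these variables reach the new TuGraph connector (the import path and the "default" graph name are assumptions based on this diff, not a documented API):

import os
from dbgpt.datasource.conn_tugraph import TuGraphConnector  # import path assumed

# Read the same variables documented above, with the defaults shown there.
connector = TuGraphConnector.from_uri_db(
    host=os.getenv("TUGRAPH_HOST", "127.0.0.1"),
    port=int(os.getenv("TUGRAPH_PORT", "7687")),
    user=os.getenv("TUGRAPH_USERNAME", "admin"),
    pwd=os.getenv("TUGRAPH_PASSWORD", "73@TuGraph"),
    db_name="default",  # graph name; "default" is an assumption
)
print(connector.get_table_names())  # {'vertex_tables': [...], 'edge_tables': [...]}
connector.close()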
#*******************************************************************#
#** WebServer Language Support **#
#*******************************************************************#

View File

@ -112,3 +112,6 @@ ignore_missing_imports = True
[mypy-ollama.*]
ignore_missing_imports = True
[mypy-networkx.*]
ignore_missing_imports = True

View File

@ -254,6 +254,9 @@ class Config(metaclass=Singleton):
self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 50))
self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
self.KNOWLEDGE_GRAPH_SEARCH_TOP_SIZE = int(
os.getenv("KNOWLEDGE_GRAPH_SEARCH_TOP_SIZE", 50)
)
self.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD = int(
os.getenv("KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD", 10)
)

View File

@ -13,6 +13,7 @@ from dbgpt.app.knowledge.request.request import (
DocumentSummaryRequest,
DocumentSyncRequest,
EntityExtractRequest,
GraphVisRequest,
KnowledgeDocumentRequest,
KnowledgeQueryRequest,
KnowledgeSpaceRequest,
@ -75,7 +76,7 @@ def space_delete(request: KnowledgeSpaceRequest):
try:
return Result.succ(knowledge_space_service.delete_space(request.name))
except Exception as e:
return Result.failed(code="E000X", msg=f"space list error {e}")
return Result.failed(code="E000X", msg=f"space delete error {e}")
@router.post("/knowledge/{space_name}/arguments")
@ -84,7 +85,7 @@ def arguments(space_name: str):
try:
return Result.succ(knowledge_space_service.arguments(space_name))
except Exception as e:
return Result.failed(code="E000X", msg=f"space list error {e}")
return Result.failed(code="E000X", msg=f"space arguments error {e}")
@router.post("/knowledge/{space_name}/argument/save")
@ -95,7 +96,7 @@ def arguments_save(space_name: str, argument_request: SpaceArgumentRequest):
knowledge_space_service.argument_save(space_name, argument_request)
)
except Exception as e:
return Result.failed(code="E000X", msg=f"space list error {e}")
return Result.failed(code="E000X", msg=f"space save error {e}")
@router.post("/knowledge/{space_name}/document/add")
@ -156,6 +157,20 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):
return Result.failed(code="E000X", msg=f"document list error {e}")
@router.post("/knowledge/{space_name}/graphvis")
def graph_vis(space_name: str, query_request: GraphVisRequest):
print(f"/document/list params: {space_name}, {query_request}")
print(query_request.limit)
try:
return Result.succ(
knowledge_space_service.query_graph(
space_name=space_name, limit=query_request.limit
)
)
except Exception as e:
return Result.failed(code="E000X", msg=f"get graph vis error {e}")
@router.post("/knowledge/{space_name}/document/delete")
def document_delete(space_name: str, query_request: DocumentQueryRequest):
print(f"/document/list params: {space_name}, {query_request}")
@ -164,7 +179,7 @@ def document_delete(space_name: str, query_request: DocumentQueryRequest):
knowledge_space_service.delete_document(space_name, query_request.doc_name)
)
except Exception as e:
return Result.failed(code="E000X", msg=f"document list error {e}")
return Result.failed(code="E000X", msg=f"document delete error {e}")
@router.post("/knowledge/{space_name}/document/upload")
@ -232,7 +247,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
@router.post("/knowledge/{space_name}/document/sync_batch")
def batch_document_sync(
async def batch_document_sync(
space_name: str,
request: List[KnowledgeSyncRequest],
service: Service = Depends(get_rag_service),
@ -242,13 +257,13 @@ def batch_document_sync(
space = service.get({"name": space_name})
for sync_request in request:
sync_request.space_id = space.id
doc_ids = service.sync_document(requests=request)
doc_ids = await service.sync_document(requests=request)
# doc_ids = service.sync_document(
# space_name=space_name, sync_requests=request
# )
return Result.succ({"tasks": doc_ids})
except Exception as e:
return Result.failed(code="E000X", msg=f"document sync error {e}")
return Result.failed(code="E000X", msg=f"document sync batch error {e}")
@router.post("/knowledge/{space_name}/chunk/list")

View File

@ -53,6 +53,10 @@ class DocumentQueryRequest(BaseModel):
page_size: int = 20
class GraphVisRequest(BaseModel):
limit: int = 100
class DocumentSyncRequest(BaseModel):
"""Sync request"""

View File

@ -25,8 +25,9 @@ from dbgpt.app.knowledge.request.response import (
)
from dbgpt.component import ComponentType
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.core import Chunk
from dbgpt.core import Chunk, LLMClient
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.rag.assembler.summary import SummaryAssembler
from dbgpt.rag.chunk_manager import ChunkParameters
@ -39,7 +40,7 @@ from dbgpt.rag.text_splitter.text_splitter import (
)
from dbgpt.serve.rag.api.schemas import KnowledgeSyncRequest
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
from dbgpt.serve.rag.service.service import Service, SyncStatus
from dbgpt.serve.rag.service.service import SyncStatus
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
@ -52,7 +53,6 @@ document_chunk_dao = DocumentChunkDao()
logger = logging.getLogger(__name__)
CFG = Config()
# Default maximum number of iterations when calling the LLM for summaries.
DEFAULT_SUMMARY_MAX_ITERATION = 5
# Default concurrency when calling the LLM for summaries.
@ -70,6 +70,13 @@ class KnowledgeService:
def __init__(self):
pass
@property
def llm_client(self) -> LLMClient:
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
return DefaultLLMClient(worker_manager, True)
def create_knowledge_space(self, request: KnowledgeSpaceRequest):
"""create knowledge space
Args:
@ -332,16 +339,23 @@ class KnowledgeService:
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
if len(spaces) != 1:
raise Exception(f"invalid space name:{space_name}")
space = spaces[0]
from dbgpt.storage.vector_store.base import VectorStoreConfig
config = VectorStoreConfig(
name=space_name,
name=space.name,
embedding_fn=embedding_fn,
max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
llm_client=self.llm_client,
model_name=self.model_name,
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
vector_store_type=space.vector_type, vector_store_config=config
)
knowledge = KnowledgeFactory.create(
datasource=doc.content,
@ -442,21 +456,27 @@ class KnowledgeService:
Args:
- space_name: knowledge space name
"""
query = KnowledgeSpaceEntity(name=space_name)
spaces = knowledge_space_dao.get_knowledge_space(query)
if len(spaces) == 0:
raise Exception(f"delete error, no space name:{space_name} in database")
spaces = knowledge_space_dao.get_knowledge_space(
KnowledgeSpaceEntity(name=space_name)
)
if len(spaces) != 1:
raise Exception(f"invalid space name:{space_name}")
space = spaces[0]
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
config = VectorStoreConfig(name=space.name, embedding_fn=embedding_fn)
config = VectorStoreConfig(
name=space.name,
embedding_fn=embedding_fn,
llm_client=self.llm_client,
model_name=None,
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
vector_store_type=space.vector_type, vector_store_config=config
)
# delete vectors
vector_store_connector.delete_vector_name(space.name)
@ -480,6 +500,12 @@ class KnowledgeService:
documents = knowledge_document_dao.get_documents(document_query)
if len(documents) != 1:
raise Exception(f"there are no or more than one document called {doc_name}")
spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
if len(spaces) != 1:
raise Exception(f"invalid space name:{space_name}")
space = spaces[0]
vector_ids = documents[0].vector_ids
if vector_ids is not None:
embedding_factory = CFG.SYSTEM_APP.get_component(
@ -488,10 +514,14 @@ class KnowledgeService:
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
config = VectorStoreConfig(
name=space.name,
embedding_fn=embedding_fn,
llm_client=self.llm_client,
model_name=None,
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
vector_store_type=space.vector_type, vector_store_config=config
)
# delete vector by ids
vector_store_connector.delete_by_ids(vector_ids)
@ -535,7 +565,7 @@ class KnowledgeService:
"""
logger.info(
f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}"
)
try:
with root_tracer.start_span(
@ -645,3 +675,40 @@ class KnowledgeService:
**{"chat_param": chat_param},
)
return chat
def query_graph(self, space_name, limit):
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
if len(spaces) != 1:
raise Exception(f"invalid space name:{space_name}")
space = spaces[0]
logger.info(f"query graph for space {space_name} with model {CFG.LLM_MODEL}")
config = VectorStoreConfig(
name=space.name,
embedding_fn=embedding_fn,
max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
llm_client=self.llm_client,
model_name=None,
)
vector_store_connector = VectorStoreConnector(
vector_store_type=space.vector_type, vector_store_config=config
)
graph = vector_store_connector.client.query_graph(limit=limit)
res = {"nodes": [], "edges": []}
for node in graph.vertices():
res["nodes"].append({"vid": node.vid})
for edge in graph.edges():
res["edges"].append(
{
"src": edge.sid,
"dst": edge.tid,
"label": edge.props[graph.edge_label],
}
)
return res
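Given the loop above, a returned graph would look like this (values illustrative):

{
    "nodes": [{"vid": "Alice"}, {"vid": "Bob"}],
    "edges": [{"src": "Alice", "dst": "Bob", "label": "is mother of"}],
}

The frontend graph page consumes exactly these vid/src/dst/label fields when building its Cytoscape elements.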

View File

@ -9,6 +9,7 @@ from dbgpt.app.knowledge.document_db import (
KnowledgeDocumentDao,
KnowledgeDocumentEntity,
)
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
@ -50,7 +51,7 @@ class ChatKnowledge(BaseChat):
)
self.space_context = self.get_space_context(self.knowledge_space)
self.top_k = (
CFG.KNOWLEDGE_SEARCH_TOP_SIZE
self.get_knowledge_search_top_size(self.knowledge_space)
if self.space_context is None
else int(self.space_context["embedding"]["topk"])
)
@ -73,12 +74,27 @@ class ChatKnowledge(BaseChat):
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
from dbgpt.serve.rag.models.models import (
KnowledgeSpaceDao,
KnowledgeSpaceEntity,
)
from dbgpt.storage.vector_store.base import VectorStoreConfig
config = VectorStoreConfig(name=self.knowledge_space, embedding_fn=embedding_fn)
spaces = KnowledgeSpaceDao().get_knowledge_space(
KnowledgeSpaceEntity(name=self.knowledge_space)
)
if len(spaces) != 1:
raise Exception(f"invalid space name:{self.knowledge_space}")
space = spaces[0]
config = VectorStoreConfig(
name=space.name,
embedding_fn=embedding_fn,
llm_client=self.llm_client,
llm_model=self.llm_model,
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
vector_store_type=space.vector_type, vector_store_config=config
)
query_rewrite = None
if CFG.KNOWLEDGE_SEARCH_REWRITE:
@ -239,6 +255,18 @@ class ChatKnowledge(BaseChat):
service = KnowledgeService()
return service.get_space_context(space_name)
def get_knowledge_search_top_size(self, space_name) -> int:
service = KnowledgeService()
request = KnowledgeSpaceRequest(name=space_name)
spaces = service.get_knowledge_space(request)
if len(spaces) == 1:
from dbgpt.storage import vector_store
if spaces[0].vector_type in vector_store.__knowledge_graph__:
return CFG.KNOWLEDGE_GRAPH_SEARCH_TOP_SIZE
return CFG.KNOWLEDGE_SEARCH_TOP_SIZE
async def execute_similar_search(self, query):
"""execute similarity search"""
with root_tracer.start_span(

File diff suppressed because one or more lines are too long (2 files)

[Generated Next.js build manifests (self.__BUILD_MANIFEST / self.__SSG_MANIFEST, minified, four near-identical variants): the updated copies register the new "/knowledge/graph" page route alongside the existing pages.]

File diff suppressed because one or more lines are too long (repeated for 33 generated frontend files)

View File

@ -0,0 +1 @@
(self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[7917],{99790:function(e,t,r){(window.__NEXT_P=window.__NEXT_P||[]).push(["/knowledge/graph",function(){return r(15219)}])},15219:function(e,t,r){"use strict";r.r(t);var l=r(85893),a=r(67294),n=r(1387),o=r(13840),i=r.n(o),s=r(71577),c=r(71965),u=r(89182),d=r(11163);n.Z.use(i());let f={name:"euler",springLength:340,fit:!1,springCoeff:8e-4,mass:20,dragCoeff:1,gravity:-20,pull:.009,randomize:!1,padding:0,maxIterations:1e3,maxSimulationTime:4e3};t.default=function(){let e=(0,a.useRef)(null),t=(0,d.useRouter)(),r=async()=>{let[t,r]=await (0,u.Vx)((0,u.FT)(p,{limit:500}));e.current&&r&&i(o(r))},o=e=>{let t=[],r=[];return e.nodes.forEach(e=>{let r={data:{id:e.vid,displayName:e.vid}};t.push(r)}),e.edges.forEach(e=>{let t={data:{id:e.src+"_"+e.dst+"_"+e.label,source:e.src,target:e.dst,displayName:e.label}};r.push(t)}),{nodes:t,edges:r}},i=t=>{let r=e.current,l=(0,n.Z)({container:e.current,elements:t,zoom:.3,pixelRatio:"auto",style:[{selector:"node",style:{width:60,height:60,color:"#fff","text-outline-color":"#37D4BE","text-outline-width":2,"text-valign":"center","text-halign":"center","background-color":"#37D4BE",label:"data(displayName)"}},{selector:"edge",style:{width:1,color:"#fff",label:"data(displayName)","line-color":"#66ADFF","font-size":14,"target-arrow-shape":"vee","control-point-step-size":40,"curve-style":"bezier","text-background-opacity":1,"text-background-color":"#66ADFF","target-arrow-color":"#66ADFF","text-background-shape":"roundrectangle","text-border-color":"#000","text-wrap":"wrap","text-valign":"top","text-halign":"center","text-background-padding":"5"}}]});l.layout(f).run(),l.pan({x:r.clientWidth/2,y:r.clientHeight/2})},{query:{spaceName:p}}=(0,d.useRouter)();return(0,a.useEffect)(()=>{p&&r()}),(0,l.jsxs)("div",{className:"p-4 h-full overflow-y-scroll relative px-2",children:[(0,l.jsx)("div",{children:(0,l.jsx)(s.ZP,{onClick:()=>{t.push("/knowledge")},icon:(0,l.jsx)(c.Z,{}),children:" Back "})}),(0,l.jsx)("div",{className:"h-full w-full",ref:e})]})}}},function(e){e.O(0,[9209,193,9774,2888,179],function(){return e(e.s=99790)}),_N_E=e.O()}]);

View File

@ -0,0 +1 @@
(self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[7917],{99790:function(e,t,r){(window.__NEXT_P=window.__NEXT_P||[]).push(["/knowledge/graph",function(){return r(15219)}])},15219:function(e,t,r){"use strict";r.r(t);var n=r(85893),l=r(67294),a=r(1387),o=r(13840),i=r.n(o),s=r(71577),c=r(71965),u=r(89182),d=r(11163);a.Z.use(i());let p={name:"euler",springLength:340,fit:!1,springCoeff:8e-4,mass:20,dragCoeff:1,gravity:-20,pull:.009,randomize:!1,padding:0,maxIterations:1e3,maxSimulationTime:4e3};t.default=function(){let e=(0,l.useRef)(null),t=(0,d.useRouter)(),r=async()=>{let t=await (0,u.FT)(f,{limit:500});e.current&&i(o(t))},o=e=>{let t=[],r=[];return e.nodes.forEach(e=>{let r={data:{id:e.identity,displayName:e.properties.id}};t.push(r)}),e.relationships.forEach(e=>{let t={data:{id:e.src+"_"+e.dst,source:e.src,target:e.dst,displayName:e.properties.id}};r.push(t)}),{nodes:t,edges:r}},i=t=>{let r=e.current,n=(0,a.Z)({container:e.current,elements:t,zoom:.5,style:[{selector:"node",style:{width:60,height:60,color:"#fff","text-outline-color":"#37D4BE","text-outline-width":2,"text-valign":"center","text-halign":"center","background-color":"#37D4BE",label:"data(displayName)"}},{selector:"edge",style:{width:1,color:"#fff",label:"data(displayName)","line-color":"#66ADFF","font-size":14,"target-arrow-shape":"vee","control-point-step-size":40,"curve-style":"bezier","text-background-opacity":1,"text-background-color":"#66ADFF","target-arrow-color":"#66ADFF","text-background-shape":"roundrectangle","text-border-color":"#000","text-wrap":"wrap","text-valign":"top","text-halign":"center","text-background-padding":"5"}}]});n.layout(p).run(),n.pan({x:r.clientWidth/2,y:r.clientHeight/2})},{query:{spaceName:f}}=(0,d.useRouter)();return(0,l.useEffect)(()=>{f&&r()}),(0,n.jsxs)("div",{className:"p-4 h-full overflow-y-scroll relative px-2",children:[(0,n.jsx)("div",{children:(0,n.jsx)(s.ZP,{onClick:()=>{t.push("/knowledge")},icon:(0,n.jsx)(c.Z,{}),children:" Back "})}),(0,n.jsx)("div",{className:"h-full w-full",ref:e})]})}}},function(e){e.O(0,[9209,193,9774,2888,179],function(){return e(e.s=99790)}),_N_E=e.O()}]);

File diff suppressed because one or more lines are too long (repeated for 10 generated frontend files)

[Another generated build manifest and SSG manifest pair, identical in shape to those above, with the "/knowledge/graph" route registered.]

File diff suppressed because one or more lines are too long (repeated for 11 generated frontend files)

Binary file not shown (new image, 40 KiB)

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,7 @@
"""TuGraph Connector."""
import json
from typing import Any, Dict, List, cast
from typing import Dict, List, cast
from .base import BaseConnector
@ -12,37 +13,60 @@ class TuGraphConnector(BaseConnector):
driver: str = "bolt"
dialect: str = "tugraph"
def __init__(self, session):
def __init__(self, driver, graph):
"""Initialize the connector with a Neo4j driver."""
self._session = session
self._driver = driver
self._schema = None
self._graph = graph
self._session = None
def create_graph(self, graph_name: str) -> None:
"""Create a new graph."""
# check whether the graph exists before creating it
with self._driver.session(database="default") as session:
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
exists = any(item["graph_name"] == graph_name for item in graph_list)
if not exists:
session.run(f"CALL dbms.graph.createGraph('{graph_name}', '', 2048)")
def delete_graph(self, graph_name: str) -> None:
"""Delete a graph."""
with self._driver.session(database="default") as session:
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
exists = any(item["graph_name"] == graph_name for item in graph_list)
if exists:
session.run(f"Call dbms.graph.deleteGraph('{graph_name}')")
@classmethod
def from_uri_db(
cls, host: str, port: int, user: str, pwd: str, db_name: str, **kwargs: Any
cls, host: str, port: int, user: str, pwd: str, db_name: str
) -> "TuGraphConnector":
"""Create a new TuGraphConnector from host, port, user, pwd, db_name."""
try:
from neo4j import GraphDatabase
db_url = f"{cls.driver}://{host}:{str(port)}"
with GraphDatabase.driver(db_url, auth=(user, pwd)) as client:
client.verify_connectivity()
session = client.session(database=db_name)
return cast(TuGraphConnector, cls(session=session))
driver = GraphDatabase.driver(db_url, auth=(user, pwd))
driver.verify_connectivity()
return cast(TuGraphConnector, cls(driver=driver, graph=db_name))
except ImportError as err:
raise ImportError("requests package is not installed") from err
raise ImportError(
"neo4j package is not installed, please install it with "
"`pip install neo4j`"
) from err
def get_table_names(self) -> Dict[str, List[str]]:
"""Get all table names from the TuGraph database using the Neo4j driver."""
# Run the query to get vertex labels
v_result = self._session.run("CALL db.vertexLabels()").data()
v_data = [table_name["label"] for table_name in v_result]
"""Get all table names from the TuGraph by Neo4j driver."""
# run the query to get vertex labels
with self._driver.session(database=self._graph) as session:
v_result = session.run("CALL db.vertexLabels()").data()
v_data = [table_name["label"] for table_name in v_result]
# Run the query to get edge labels
e_result = self._session.run("CALL db.edgeLabels()").data()
e_data = [table_name["label"] for table_name in e_result]
return {"vertex_tables": v_data, "edge_tables": e_data}
# run the query to get edge labels
e_result = session.run("CALL db.edgeLabels()").data()
e_data = [table_name["label"] for table_name in e_result]
return {"vertex_tables": v_data, "edge_tables": e_data}
def get_grants(self):
"""Get grants."""
@ -62,11 +86,13 @@ class TuGraphConnector(BaseConnector):
def close(self):
"""Close the Neo4j driver."""
self._session.close()
self._driver.close()
def run(self):
def run(self, query: str, fetch: str = "all") -> List:
"""Run GQL."""
return []
with self._driver.session(database=self._graph) as session:
result = session.run(query)
return list(result)
def get_columns(self, table_name: str, table_type: str = "vertex") -> List[Dict]:
"""Get fields about specified graph.
@ -80,27 +106,27 @@ class TuGraphConnector(BaseConnector):
eg:[{'name': 'id', 'type': 'int', 'default_expression': '',
'is_in_primary_key': True, 'comment': 'id'}, ...]
"""
data = []
result = None
if table_type == "vertex":
result = self._session.run(
f"CALL db.getVertexSchema('{table_name}')"
).data()
else:
result = self._session.run(f"CALL db.getEdgeSchema('{table_name}')").data()
schema_info = json.loads(result[0]["schema"])
for prop in schema_info.get("properties", []):
prop_dict = {
"name": prop["name"],
"type": prop["type"],
"default_expression": "",
"is_in_primary_key": bool(
"primary" in schema_info and prop["name"] == schema_info["primary"]
),
"comment": prop["name"],
}
data.append(prop_dict)
return data
with self._driver.session(database=self._graph) as session:
data = []
result = None
if table_type == "vertex":
result = session.run(f"CALL db.getVertexSchema('{table_name}')").data()
else:
result = session.run(f"CALL db.getEdgeSchema('{table_name}')").data()
schema_info = json.loads(result[0]["schema"])
for prop in schema_info.get("properties", []):
prop_dict = {
"name": prop["name"],
"type": prop["type"],
"default_expression": "",
"is_in_primary_key": bool(
"primary" in schema_info
and prop["name"] == schema_info["primary"]
),
"comment": prop["name"],
}
data.append(prop_dict)
return data
def get_indexes(self, table_name: str, table_type: str = "vertex") -> List[Dict]:
"""Get table indexes about specified table.
@ -112,14 +138,15 @@ class TuGraphConnector(BaseConnector):
List[Dict]:eg:[{'name': 'idx_key', 'column_names': ['id']}]
"""
# [{'name':'id','column_names':['id']}]
result = self._session.run(
f"CALL db.listLabelIndexes('{table_name}','{table_type}')"
).data()
transformed_data = []
for item in result:
new_dict = {"name": item["field"], "column_names": [item["field"]]}
transformed_data.append(new_dict)
return transformed_data
with self._driver.session(database=self._graph) as session:
result = session.run(
f"CALL db.listLabelIndexes('{table_name}','{table_type}')"
).data()
transformed_data = []
for item in result:
new_dict = {"name": item["field"], "column_names": [item["field"]]}
transformed_data.append(new_dict)
return transformed_data
@classmethod
def is_graph_type(cls) -> bool:
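Putting the reworked connector together, a usage sketch (graph name is illustrative; the import path is an assumption, but the procedure calls are the ones shown in this diff):

from dbgpt.datasource.conn_tugraph import TuGraphConnector  # import path assumed

conn = TuGraphConnector.from_uri_db(
    host="127.0.0.1", port=7687, user="admin", pwd="73@TuGraph", db_name="default"
)
conn.create_graph("demo_graph")        # no-op if the graph already exists
labels = conn.get_table_names()        # {'vertex_tables': [...], 'edge_tables': [...]}
rows = conn.run("CALL db.vertexLabels()")  # runs against the configured graph
conn.delete_graph("demo_graph")
conn.close()                           # closes the shared Neo4j driver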

View File

@ -1,4 +1,5 @@
"""Base Assembler."""
from abc import ABC, abstractmethod
from typing import Any, List, Optional
@ -37,13 +38,15 @@ class BaseAssembler(ABC):
)
self._chunks: List[Chunk] = []
metadata = {
"knowledge_cls": self._knowledge.__class__.__name__
if self._knowledge
else None,
"knowledge_cls": (
self._knowledge.__class__.__name__ if self._knowledge else None
),
"knowledge_type": self._knowledge.type().value if self._knowledge else None,
"path": self._knowledge._path
if self._knowledge and hasattr(self._knowledge, "_path")
else None,
"path": (
self._knowledge._path
if self._knowledge and hasattr(self._knowledge, "_path")
else None
),
"chunk_parameters": self._chunk_parameters.dict(),
}
with root_tracer.start_span("BaseAssembler.load_knowledge", metadata=metadata):
@ -70,6 +73,14 @@ class BaseAssembler(ABC):
List[str]: List of persisted chunk ids.
"""
async def apersist(self) -> List[str]:
"""Persist chunks.
Returns:
List[str]: List of persisted chunk ids.
"""
raise NotImplementedError
def get_chunks(self) -> List[Chunk]:
"""Return chunks."""
return self._chunks

View File

@ -106,6 +106,14 @@ class EmbeddingAssembler(BaseAssembler):
"""
return self._vector_store_connector.load_document(self._chunks)
async def apersist(self) -> List[str]:
"""Persist chunks into store.
Returns:
List[str]: List of chunk ids.
"""
return await self._vector_store_connector.aload_document(self._chunks)
def _extract_info(self, chunks) -> List[Chunk]:
"""Extract info from chunks."""
return []
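A hedged sketch of driving the new async persistence path (assembler construction is elided; apersist() awaits VectorStoreConnector.aload_document() instead of blocking the event loop the way persist() does):

import asyncio
from typing import List

from dbgpt.rag.assembler import EmbeddingAssembler  # re-export used elsewhere in this diff

async def persist_space(assembler: EmbeddingAssembler) -> List[str]:
    return await assembler.apersist()  # returns the persisted chunk ids

# chunk_ids = asyncio.run(persist_space(assembler))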

View File

@ -0,0 +1 @@
"""Module for index."""

dbgpt/rag/index/base.py (new file, 168 lines)
View File

@ -0,0 +1,168 @@
"""Index store base class."""
import logging
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
from dbgpt.core import Chunk, Embeddings
from dbgpt.storage.vector_store.filters import MetadataFilters
logger = logging.getLogger(__name__)
class IndexStoreConfig(BaseModel):
"""Index store config."""
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
name: str = Field(
default="dbgpt_collection",
description="The name of index store, if not set, will use the default name.",
)
embedding_fn: Optional[Embeddings] = Field(
default=None,
description="The embedding function of vector store, if not set, will use the "
"default embedding function.",
)
max_chunks_once_load: int = Field(
default=10,
description="The max number of chunks to load at once. If your document is "
"large, you can set this value to a larger number to speed up the loading "
"process. Default is 10.",
)
max_threads: int = Field(
default=1,
description="The max number of threads to use. Default is 1. If you set this "
"bigger than 1, please make sure your vector store is thread-safe.",
)
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert to dict."""
return model_to_dict(self, **kwargs)
class IndexStoreBase(ABC):
"""Index store base class."""
@abstractmethod
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document in index database.
Args:
chunks(List[Chunk]): document chunks.
Return:
List[str]: chunk ids.
"""
@abstractmethod
def aload_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document in index database.
Args:
chunks(List[Chunk]): document chunks.
Return:
List[str]: chunk ids.
"""
@abstractmethod
def similar_search_with_scores(
self,
text,
topk,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Similar search with scores in index database.
Args:
text(str): The query text.
topk(int): The number of similar documents to return.
score_threshold(float): Optional, a floating point value
between 0 and 1.
filters(Optional[MetadataFilters]): metadata filters.
Return:
List[Chunk]: The similar documents.
"""
@abstractmethod
def delete_by_ids(self, ids: str):
"""Delete docs.
Args:
ids(str): The vector ids to delete, separated by comma.
"""
@abstractmethod
def delete_vector_name(self, index_name: str):
"""Delete index by name.
Args:
index_name(str): The name of index to delete.
"""
def load_document_with_limit(
self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
) -> List[str]:
"""Load document in index database with specified limit.
Args:
chunks(List[Chunk]): Document chunks.
max_chunks_once_load(int): Max number of chunks to load at once.
max_threads(int): Max number of threads to use.
Return:
List[str]: Chunk ids.
"""
# Group the chunks into batches of at most max_chunks_once_load
chunk_groups = [
chunks[i : i + max_chunks_once_load]
for i in range(0, len(chunks), max_chunks_once_load)
]
logger.info(
f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups with "
f"{max_threads} threads."
)
ids = []
loaded_cnt = 0
start_time = time.time()
with ThreadPoolExecutor(max_workers=max_threads) as executor:
tasks = []
for chunk_group in chunk_groups:
tasks.append(executor.submit(self.load_document, chunk_group))
for future in tasks:
success_ids = future.result()
ids.extend(success_ids)
loaded_cnt += len(success_ids)
logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
logger.info(
f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
)
return ids
def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search in index database.
Args:
text(str): The query text.
topk(int): The number of similar documents to return.
filters(Optional[MetadataFilters]): metadata filters.
Return:
List[Chunk]: The similar documents.
"""
return self.similar_search_with_scores(text, topk, 1.0, filters)
async def asimilar_search_with_scores(
self,
doc: str,
topk: int,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Aynsc similar_search_with_score in vector database."""
return self.similar_search_with_scores(doc, topk, score_threshold, filters)
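To illustrate the batching in load_document_with_limit: the same slicing applied to 25 stand-in chunks with max_chunks_once_load=10 yields groups of 10, 10 and 5, each submitted as one load_document() call on the thread pool.

chunks = list(range(25))  # stand-ins for Chunk objects
max_chunks_once_load = 10
chunk_groups = [
    chunks[i : i + max_chunks_once_load]
    for i in range(0, len(chunks), max_chunks_once_load)
]
print([len(g) for g in chunk_groups])  # [10, 10, 5]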

View File

@ -229,6 +229,6 @@ class EmbeddingRetriever(BaseRetriever):
self, query, score_threshold, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search with score."""
return self._vector_store_connector.similar_search_with_scores(
return await self._vector_store_connector.asimilar_search_with_scores(
query, self._top_k, score_threshold, filters
)

View File

@ -0,0 +1 @@
"""Module for transformer."""

View File

@ -0,0 +1,26 @@
"""Transformer base class."""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
logger = logging.getLogger(__name__)
class TransformerBase:
"""Transformer base class."""
class EmbedderBase(TransformerBase, ABC):
"""Embedder base class."""
class ExtractorBase(TransformerBase, ABC):
"""Extractor base class."""
@abstractmethod
async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Extract results from text."""
class TranslatorBase(TransformerBase, ABC):
"""Translator base class."""

View File

@ -0,0 +1,50 @@
"""KeywordExtractor class."""
import logging
from typing import List, Optional
from dbgpt.core import LLMClient
from dbgpt.rag.transformer.llm_extractor import LLMExtractor
KEYWORD_EXTRACT_PT = (
"A question is provided below. Given the question, extract up to "
"keywords from the text. Focus on extracting the keywords that we can use "
"to best lookup answers to the question.\n"
"Generate as more as possible synonyms or alias of the keywords "
"considering possible cases of capitalization, pluralization, "
"common expressions, etc.\n"
"Avoid stopwords.\n"
"Provide the keywords and synonyms in comma-separated format."
"Formatted keywords and synonyms text should be separated by a semicolon.\n"
"---------------------\n"
"Example:\n"
"Text: Alice is Bob's mother.\n"
"Keywords:\nAlice,mother,Bob;mummy\n"
"Text: Philz is a coffee shop founded in Berkeley in 1982.\n"
"Keywords:\nPhilz,coffee shop,Berkeley,1982;coffee bar,coffee house\n"
"---------------------\n"
"Text: {text}\n"
"Keywords:\n"
)
logger = logging.getLogger(__name__)
class KeywordExtractor(LLMExtractor):
"""KeywordExtractor class."""
def __init__(self, llm_client: LLMClient, model_name: str):
"""Initialize the KeywordExtractor."""
super().__init__(llm_client, model_name, KEYWORD_EXTRACT_PT)
def _parse_response(self, text: str, limit: Optional[int] = None) -> List[str]:
keywords = set()
for part in text.split(";"):
for s in part.strip().split(","):
keyword = s.strip()
if keyword:
keywords.add(keyword)
if limit and len(keywords) >= limit:
return list(keywords)
return list(keywords)
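Given the prompt format above, _parse_response splits on ";" and then ",", trimming and de-duplicating; a standalone rerun of the same logic (input string illustrative):

text = "Philz,coffee shop,Berkeley,1982;coffee bar,coffee house"

keywords = set()
for part in text.split(";"):
    for s in part.strip().split(","):
        keyword = s.strip()
        if keyword:
            keywords.add(keyword)
print(sorted(keywords))
# ['1982', 'Berkeley', 'Philz', 'coffee bar', 'coffee house', 'coffee shop']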

View File

@ -0,0 +1,50 @@
"""TripletExtractor class."""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
from dbgpt.rag.transformer.base import ExtractorBase
logger = logging.getLogger(__name__)
class LLMExtractor(ExtractorBase, ABC):
"""LLMExtractor class."""
def __init__(self, llm_client: LLMClient, model_name: str, prompt_template: str):
"""Initialize the LLMExtractor."""
self._llm_client = llm_client
self._model_name = model_name
self._prompt_template = prompt_template
async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Extract by LLm."""
template = HumanPromptTemplate.from_template(self._prompt_template)
messages = template.format_messages(text=text)
# use default model if needed
if not self._model_name:
models = await self._llm_client.models()
if not models:
raise Exception("No models available")
self._model_name = models[0].model
logger.info(f"Using model {self._model_name} to extract")
model_messages = ModelMessage.from_base_messages(messages)
request = ModelRequest(model=self._model_name, messages=model_messages)
response = await self._llm_client.generate(request=request)
if not response.success:
code = str(response.error_code)
reason = response.text
logger.error(f"request llm failed ({code}) {reason}")
return []
if limit and limit < 1:
raise ValueError("optional argument limit must be >= 1")
return self._parse_response(response.text, limit)
@abstractmethod
def _parse_response(self, text: str, limit: Optional[int] = None) -> List:
"""Parse llm response."""

View File

@ -0,0 +1,10 @@
"""Text2Cypher class."""
import logging
from dbgpt.rag.transformer.base import TranslatorBase
logger = logging.getLogger(__name__)
class Text2Cypher(TranslatorBase):
"""Text2Cypher class."""

View File

@ -0,0 +1,10 @@
"""Text2GQL class."""
import logging
from dbgpt.rag.transformer.base import TranslatorBase
logger = logging.getLogger(__name__)
class Text2GQL(TranslatorBase):
"""Text2GQL class."""

View File

@ -0,0 +1,10 @@
"""Text2Vector class."""
import logging
from dbgpt.rag.transformer.base import EmbedderBase
logger = logging.getLogger(__name__)
class Text2Vector(EmbedderBase):
"""Text2Vector class."""

View File

@ -0,0 +1,71 @@
"""TripletExtractor class."""
import logging
import re
from typing import Any, List, Optional, Tuple
from dbgpt.core import LLMClient
from dbgpt.rag.transformer.llm_extractor import LLMExtractor
logger = logging.getLogger(__name__)
TRIPLET_EXTRACT_PT = (
"Some text is provided below. Given the text, "
"extract up to knowledge triplets as more as possible "
"in the form of (subject, predicate, object).\n"
"Avoid stopwords.\n"
"---------------------\n"
"Example:\n"
"Text: Alice is Bob's mother.\n"
"Triplets:\n(Alice, is mother of, Bob)\n"
"Text: Alice has 2 apples.\n"
"Triplets:\n(Alice, has 2, apple)\n"
"Text: Alice was given 1 apple by Bob.\n"
"Triplets:(Bob, gives 1 apple, Bob)\n"
"Text: Alice was pushed by Bob.\n"
"Triplets:(Bob, pushes, Alice)\n"
"Text: Bob's mother Alice has 2 apples.\n"
"Triplets:\n(Alice, is mother of, Bob)\n(Alice, has 2, apple)\n"
"Text: A Big monkey climbed up the tall fruit tree and picked 3 peaches.\n"
"Triplets:\n(monkey, climbed up, fruit tree)\n(monkey, picked 3, peach)\n"
"Text: Alice has 2 apples, she gives 1 to Bob.\n"
"Triplets:\n"
"(Alice, has 2, apple)\n(Alice, gives 1 apple, Bob)\n"
"Text: Philz is a coffee shop founded in Berkeley in 1982.\n"
"Triplets:\n"
"(Philz, is, coffee shop)\n(Philz, founded in, Berkeley)\n"
"(Philz, founded in, 1982)\n"
"---------------------\n"
"Text: {text}\n"
"Triplets:\n"
)
class TripletExtractor(LLMExtractor):
"""TripletExtractor class."""
def __init__(self, llm_client: LLMClient, model_name: str):
"""Initialize the TripletExtractor."""
super().__init__(llm_client, model_name, TRIPLET_EXTRACT_PT)
def _parse_response(
self, text: str, limit: Optional[int] = None
) -> List[Tuple[Any, ...]]:
triplets = []
for line in text.split("\n"):
for match in re.findall(r"\((.*?)\)", line):
splits = match.split(",")
parts = [split.strip() for split in splits if split.strip()]
if len(parts) == 3:
parts = [
p.strip(
"`~!@#$%^&*()-=+[]\\{}|;':\",./<>?"
"·!¥&*()—【】、「」;‘’:“”,。、《》?"
)
for p in parts
]
triplets.append(tuple(parts))
if limit and len(triplets) >= limit:
return triplets
return triplets
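Applying the same parsing to a typical completion (input illustrative; the punctuation-stripping step is omitted here for brevity):

import re

text = "Triplets:\n(Philz, is, coffee shop)\n(Philz, founded in, Berkeley)"

triplets = []
for line in text.split("\n"):
    for match in re.findall(r"\((.*?)\)", line):
        parts = [p.strip() for p in match.split(",") if p.strip()]
        if len(parts) == 3:
            triplets.append(tuple(parts))
print(triplets)
# [('Philz', 'is', 'coffee shop'), ('Philz', 'founded in', 'Berkeley')]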

View File

@ -1,3 +1,4 @@
import asyncio
import json
import logging
import os
@ -21,8 +22,10 @@ from dbgpt.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from dbgpt.core import Chunk
from dbgpt.core import Chunk, LLMClient
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding import EmbeddingFactory
@ -71,7 +74,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
document_dao: Optional[KnowledgeDocumentDao] = None,
chunk_dao: Optional[DocumentChunkDao] = None,
):
self._system_app = None
self._system_app = system_app
self._dao: KnowledgeSpaceDao = dao
self._document_dao: KnowledgeDocumentDao = document_dao
self._chunk_dao: DocumentChunkDao = chunk_dao
@ -112,6 +115,13 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
"""Returns the internal ServeConfig."""
return self._serve_config
@property
def llm_client(self) -> LLMClient:
worker_manager = self._system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
return DefaultLLMClient(worker_manager, True)
def create_space(self, request: SpaceServeRequest) -> SpaceServeResponse:
"""Create a new Space entity
@ -198,7 +208,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
raise Exception(f"create document failed, {request.doc_name}")
return doc_id
def sync_document(self, requests: List[KnowledgeSyncRequest]) -> List:
async def sync_document(self, requests: List[KnowledgeSyncRequest]) -> List:
"""Create a new document entity
Args:
@ -236,7 +246,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
if space_context is None
else int(space_context["embedding"]["chunk_overlap"])
)
self._sync_knowledge_document(space_id, doc, chunk_parameters)
await self._sync_knowledge_document(space_id, doc, chunk_parameters)
doc_ids.append(doc.id)
return doc_ids
@ -284,10 +294,11 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
space = self.get(query_request)
if space is None:
raise HTTPException(status_code=400, detail=f"Space {space_id} not found")
config = VectorStoreConfig(name=space.name)
config = VectorStoreConfig(
name=space.name, llm_client=self.llm_client, model_name=None
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
vector_store_type=space.vector_type, vector_store_config=config
)
# delete vectors
vector_store_connector.delete_vector_name(space.name)
@ -316,12 +327,22 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
document = self._document_dao.get_one(query_request)
if document is None:
raise Exception(f"found no document or multiple documents for {document_id}")
# get space by name
spaces = self._dao.get_knowledge_space(
KnowledgeSpaceEntity(name=document.space)
)
if len(spaces) != 1:
raise Exception(f"invalid space name: {document.space}")
space = spaces[0]
vector_ids = document.vector_ids
if vector_ids is not None:
config = VectorStoreConfig(name=document.space)
config = VectorStoreConfig(
name=space.name, llm_client=self.llm_client, model_name=None
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
vector_store_type=space.vector_type, vector_store_config=config
)
# delete vector by ids
vector_store_connector.delete_by_ids(vector_ids)
@ -375,7 +396,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
"""
return self._document_dao.get_list_page(request, page, page_size)
def _batch_document_sync(
async def _batch_document_sync(
self, space_id, sync_requests: List[KnowledgeSyncRequest]
) -> List[int]:
"""batch sync knowledge document chunk into vector store
@ -413,11 +434,11 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
if space_context is None
else int(space_context["embedding"]["chunk_overlap"])
)
self._sync_knowledge_document(space_id, doc, chunk_parameters)
await self._sync_knowledge_document(space_id, doc, chunk_parameters)
doc_ids.append(doc.id)
return doc_ids
def _sync_knowledge_document(
async def _sync_knowledge_document(
self,
space_id,
doc_vo: DocumentVO,
@ -439,10 +460,11 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
name=space.name,
embedding_fn=embedding_fn,
max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
llm_client=self.llm_client,
model_name=None,
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
vector_store_type=space.vector_type, vector_store_config=config
)
knowledge = KnowledgeFactory.create(
datasource=doc.content,
@ -458,15 +480,16 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
doc.chunk_size = len(chunk_docs)
doc.gmt_modified = datetime.now()
self._document_dao.update_knowledge_document(doc)
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
executor.submit(self.async_doc_embedding, assembler, chunk_docs, doc)
# executor = CFG.SYSTEM_APP.get_component(
# ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
# ).create()
# executor.submit(self.async_doc_embedding, assembler, chunk_docs, doc)
asyncio.create_task(self.async_doc_embedding(assembler, chunk_docs, doc))
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
return chunk_docs
@trace("async_doc_embedding")
def async_doc_embedding(self, assembler, chunk_docs, doc):
async def async_doc_embedding(self, assembler, chunk_docs, doc):
"""async document embedding into vector db
Args:
- client: EmbeddingEngine Client
@ -475,14 +498,19 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
"""
logger.info(
f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}"
)
try:
with root_tracer.start_span(
"app.knowledge.assembler.persist",
metadata={"doc": doc.doc_name, "chunks": len(chunk_docs)},
):
vector_ids = assembler.persist()
# vector_ids = assembler.persist()
space = self.get({"name": doc.space})
if space and space.vector_type == "KnowledgeGraph":
vector_ids = await assembler.apersist()
else:
vector_ids = assembler.persist()
doc.status = SyncStatus.FINISHED.name
doc.result = "document embedding success"
if vector_ids is not None:

View File

@ -0,0 +1,37 @@
"""Graph Store Module."""
from typing import Tuple, Type
def _import_memgraph() -> Tuple[Type, Type]:
from dbgpt.storage.graph_store.memgraph_store import (
MemoryGraphStore,
MemoryGraphStoreConfig,
)
return MemoryGraphStore, MemoryGraphStoreConfig
def _import_tugraph() -> Tuple[Type, Type]:
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore, TuGraphStoreConfig
return TuGraphStore, TuGraphStoreConfig
def _import_neo4j() -> Tuple[Type, Type]:
from dbgpt.storage.graph_store.neo4j_store import Neo4jStore, Neo4jStoreConfig
return Neo4jStore, Neo4jStoreConfig
def __getattr__(name: str) -> Tuple[Type, Type]:
if name == "Memory":
return _import_memgraph()
elif name == "TuGraph":
return _import_tugraph()
elif name == "Neo4j":
return _import_neo4j()
else:
raise AttributeError(f"Could not find: {name}")
__all__ = ["Memory", "TuGraph", "Neo4j"]

View File

@ -0,0 +1,68 @@
"""Graph store base class."""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core import Embeddings
from dbgpt.storage.graph_store.graph import Direction, Graph
logger = logging.getLogger(__name__)
class GraphStoreConfig(BaseModel):
"""Graph store config."""
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
name: str = Field(
default="dbgpt_collection",
description="The name of graph store, inherit from index store.",
)
embedding_fn: Optional[Embeddings] = Field(
default=None,
description="The embedding function of graph store, optional.",
)
class GraphStoreBase(ABC):
"""Graph store base class."""
@abstractmethod
def insert_triplet(self, sub: str, rel: str, obj: str):
"""Add triplet."""
@abstractmethod
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
"""Get triplets."""
@abstractmethod
def delete_triplet(self, sub: str, rel: str, obj: str):
"""Delete triplet."""
@abstractmethod
def drop(self):
"""Drop graph."""
@abstractmethod
def get_schema(self, refresh: bool = False) -> str:
"""Get schema."""
@abstractmethod
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
"""Get full graph."""
@abstractmethod
def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: Optional[int] = None,
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> Graph:
"""Explore on graph."""
@abstractmethod
def query(self, query: str, **args) -> Graph:
"""Execute a query."""

View File

@ -0,0 +1,42 @@
"""Connector for vector store."""
import logging
from typing import Tuple, Type
from dbgpt.storage import graph_store
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
logger = logging.getLogger(__name__)
class GraphStoreFactory:
"""Factory for graph store."""
@staticmethod
def create(graph_store_type: str, graph_store_configure=None) -> GraphStoreBase:
"""Create a GraphStore instance.
Args:
- graph_store_type: graph store type: Memory, TuGraph, Neo4j
- graph_store_configure: an optional callable that configures the graph store config
"""
store_cls, cfg_cls = GraphStoreFactory.__find_type(graph_store_type)
try:
config = cfg_cls()
if graph_store_configure:
graph_store_configure(config)
return store_cls(config)
except Exception as e:
logger.error("create graph store failed: %s", e)
raise e
@staticmethod
def __find_type(graph_store_type: str) -> Tuple[Type, Type]:
for t in graph_store.__all__:
if t.lower() == graph_store_type.lower():
store_cls, cfg_cls = getattr(graph_store, t)
if issubclass(store_cls, GraphStoreBase) and issubclass(
cfg_cls, GraphStoreConfig
):
return store_cls, cfg_cls
raise Exception(f"Graph store {graph_store_type} not supported")

View File

@ -0,0 +1,477 @@
"""Graph store base class."""
import itertools
import json
import logging
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
import networkx as nx
logger = logging.getLogger(__name__)
class Direction(Enum):
"""Direction class."""
OUT = 0
IN = 1
BOTH = 2
class Elem(ABC):
"""Elem class."""
def __init__(self):
"""Initialize Elem."""
self._props = {}
@property
def props(self) -> Dict[str, Any]:
"""Get all the properties of Elem."""
return self._props
def set_prop(self, key: str, value: Any):
"""Set a property of ELem."""
self._props[key] = value
def get_prop(self, key: str):
"""Get one of the properties of Elem."""
return self._props.get(key)
def del_prop(self, key: str):
"""Delete a property of ELem."""
self._props.pop(key, None)
def has_props(self, **props):
"""Check if the element has the specified properties with the given values."""
return all(self._props.get(k) == v for k, v in props.items())
@abstractmethod
def format(self, label_key: Optional[str] = None):
"""Format properties into a string."""
formatted_props = [
f"{k}:{json.dumps(v)}" for k, v in self._props.items() if k != label_key
]
return f"{{{';'.join(formatted_props)}}}"
class Vertex(Elem):
"""Vertex class."""
def __init__(self, vid: str, **props):
"""Initialize Vertex."""
super().__init__()
self._vid = vid
for k, v in props.items():
self.set_prop(k, v)
@property
def vid(self) -> str:
"""Return the vertex ID."""
return self._vid
def format(self, label_key: Optional[str] = None):
"""Format vertex properties into a string."""
label = self.get_prop(label_key) if label_key else self._vid
props_str = super().format(label_key)
if props_str == "{}":
return f"({label})"
else:
return f"({label}:{props_str})"
def __str__(self):
"""Return the vertex ID as its string representation."""
return f"({self._vid})"
class Edge(Elem):
"""Edge class."""
def __init__(self, sid: str, tid: str, **props):
"""Initialize Edge."""
super().__init__()
self._sid = sid
self._tid = tid
for k, v in props.items():
self.set_prop(k, v)
@property
def sid(self) -> str:
"""Return the source vertex ID of the edge."""
return self._sid
@property
def tid(self) -> str:
"""Return the target vertex ID of the edge."""
return self._tid
def nid(self, vid):
"""Return neighbor id."""
if vid == self._sid:
return self._tid
elif vid == self._tid:
return self._sid
else:
raise ValueError(f"Get nid of {vid} on {self} failed")
def format(self, label_key: Optional[str] = None):
"""Format the edge properties into a string."""
label = self.get_prop(label_key) if label_key else ""
props_str = super().format(label_key)
if props_str == "{}":
return f"-[{label}]->" if label else "->"
else:
return f"-[{label}:{props_str}]->" if label else f"-[{props_str}]->"
def triplet(self, label_key: str) -> Tuple[str, str, str]:
"""Return a triplet."""
assert label_key, "label key is needed"
return self._sid, str(self.get_prop(label_key)), self._tid
def __str__(self):
"""Return the edge '(sid)->(tid)'."""
return f"({self._sid})->({self._tid})"
class Graph(ABC):
"""Graph class."""
@abstractmethod
def upsert_vertex(self, vertex: Vertex):
"""Add a vertex."""
@abstractmethod
def append_edge(self, edge: Edge):
"""Add an edge."""
@abstractmethod
def has_vertex(self, vid: str) -> bool:
"""Check vertex exists."""
@abstractmethod
def get_vertex(self, vid: str) -> Vertex:
"""Get a vertex."""
@abstractmethod
def get_neighbor_edges(
self,
vid: str,
direction: Direction = Direction.OUT,
limit: Optional[int] = None,
) -> Iterator[Edge]:
"""Get neighbor edges."""
@abstractmethod
def vertices(self) -> Iterator[Vertex]:
"""Get vertex iterator."""
@abstractmethod
def edges(self) -> Iterator[Edge]:
"""Get edge iterator."""
@abstractmethod
def del_vertices(self, *vids: str):
"""Delete vertices and their neighbor edges."""
@abstractmethod
def del_edges(self, sid: str, tid: str, **props):
"""Delete edges(sid -> tid) matches props."""
@abstractmethod
def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
"""Delete neighbor edges."""
@abstractmethod
def search(
self,
vids: List[str],
direct: Direction = Direction.OUT,
depth: Optional[int] = None,
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> "Graph":
"""Search on graph."""
@abstractmethod
def schema(self) -> Dict[str, Any]:
"""Get schema."""
@abstractmethod
def format(self) -> str:
"""Format graph data to string."""
class MemoryGraph(Graph):
"""Graph class."""
def __init__(self, vertex_label: Optional[str] = None, edge_label: str = "label"):
"""Initialize MemoryGraph with vertex label and edge label."""
assert edge_label, "Edge label is needed"
# metadata
self._vertex_label = vertex_label
self._edge_label = edge_label
self._vertex_prop_keys = {vertex_label} if vertex_label else set()
self._edge_prop_keys = {edge_label}
self._edge_count = 0
# init vertices, out edges, in edges index
self._vs: Any = defaultdict()
self._oes: Any = defaultdict(lambda: defaultdict(set))
self._ies: Any = defaultdict(lambda: defaultdict(set))
@property
def vertex_label(self):
"""Return the label for vertices."""
return self._vertex_label
@property
def edge_label(self):
"""Return the label for edges."""
return self._edge_label
@property
def vertex_prop_keys(self):
"""Return a set of property keys for vertices."""
return self._vertex_prop_keys
@property
def edge_prop_keys(self):
"""Return a set of property keys for edges."""
return self._edge_prop_keys
@property
def vertex_count(self):
"""Return the number of vertices in the graph."""
return len(self._vs)
@property
def edge_count(self):
"""Return the count of edges in the graph."""
return self._edge_count
def upsert_vertex(self, vertex: Vertex):
"""Insert or update a vertex based on its ID."""
if vertex.vid in self._vs:
self._vs[vertex.vid].props.update(vertex.props)
else:
self._vs[vertex.vid] = vertex
# update metadata
self._vertex_prop_keys.update(vertex.props.keys())
def append_edge(self, edge: Edge):
"""Append an edge if it doesn't exist; requires edge label."""
if self.edge_label not in edge.props.keys():
raise ValueError(f"Edge prop '{self.edge_label}' is needed")
sid = edge.sid
tid = edge.tid
if edge in self._oes[sid][tid]:
return False
# init vertex index
self._vs.setdefault(sid, Vertex(sid))
self._vs.setdefault(tid, Vertex(tid))
# update edge index
self._oes[sid][tid].add(edge)
self._ies[tid][sid].add(edge)
# update metadata
self._edge_prop_keys.update(edge.props.keys())
self._edge_count += 1
return True
def has_vertex(self, vid: str) -> bool:
"""Retrieve a vertex by ID."""
return vid in self._vs
def get_vertex(self, vid: str) -> Vertex:
"""Retrieve a vertex by ID."""
return self._vs[vid]
def get_neighbor_edges(
self,
vid: str,
direction: Direction = Direction.OUT,
limit: Optional[int] = None,
) -> Iterator[Edge]:
"""Get edges connected to a vertex by direction."""
if direction == Direction.OUT:
es = (e for es in self._oes[vid].values() for e in es)
elif direction == Direction.IN:
es = iter(e for es in self._ies[vid].values() for e in es)
elif direction == Direction.BOTH:
oes = (e for es in self._oes[vid].values() for e in es)
ies = (e for es in self._ies[vid].values() for e in es)
# merge
tuples = itertools.zip_longest(oes, ies)
es = (e for t in tuples for e in t if e is not None)
# distinct
seen = set()
# es = (e for e in es if e not in seen and not seen.add(e))
def unique_elements(elements):
for element in elements:
if element not in seen:
seen.add(element)
yield element
es = unique_elements(es)
else:
raise ValueError(f"Invalid direction: {direction}")
return itertools.islice(es, limit) if limit else es
def vertices(self) -> Iterator[Vertex]:
"""Return vertices."""
return iter(self._vs.values())
def edges(self) -> Iterator[Edge]:
"""Return edges."""
return iter(e for nbs in self._oes.values() for es in nbs.values() for e in es)
def del_vertices(self, *vids: str):
"""Delete specified vertices."""
for vid in vids:
self.del_neighbor_edges(vid, Direction.BOTH)
self._vs.pop(vid, None)
def del_edges(self, sid: str, tid: str, **props):
"""Delete edges."""
old_edge_cnt = len(self._oes[sid][tid])
if not props:
self._edge_count -= old_edge_cnt
self._oes[sid].pop(tid, None)
self._ies[tid].pop(sid, None)
return
def remove_matches(es):
return set(filter(lambda e: not e.has_props(**props), es))
self._oes[sid][tid] = remove_matches(self._oes[sid][tid])
self._ies[tid][sid] = remove_matches(self._ies[tid][sid])
self._edge_count -= old_edge_cnt - len(self._oes[sid][tid])
def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
"""Delete all neighbor edges."""
def del_index(idx, i_idx):
for nid in idx[vid].keys():
self._edge_count -= len(i_idx[nid][vid])
i_idx[nid].pop(vid, None)
idx.pop(vid, None)
if direction in [Direction.OUT, Direction.BOTH]:
del_index(self._oes, self._ies)
if direction in [Direction.IN, Direction.BOTH]:
del_index(self._ies, self._oes)
def search(
self,
vids: List[str],
direct: Direction = Direction.OUT,
depth: Optional[int] = None,
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> "MemoryGraph":
"""Search the graph from a vertex with specified parameters."""
subgraph = MemoryGraph()
for vid in vids:
self.__search(vid, direct, depth, fan, limit, 0, set(), subgraph)
return subgraph
def __search(
self,
vid: str,
direct: Direction,
depth: Optional[int],
fan: Optional[int],
limit: Optional[int],
_depth: int,
_visited: Set,
_subgraph: "MemoryGraph",
):
if vid in _visited or (depth and _depth >= depth):
return
# visit vertex
if not self.has_vertex(vid):
return
_subgraph.upsert_vertex(self.get_vertex(vid))
_visited.add(vid)
# visit edges
nids = set()
for edge in self.get_neighbor_edges(vid, direct, fan):
if limit and _subgraph.edge_count >= limit:
return
# append edge success then visit new vertex
if _subgraph.append_edge(edge):
nid = edge.nid(vid)
if nid not in _visited:
nids.add(nid)
# next hop
for nid in nids:
self.__search(
nid, direct, depth, fan, limit, _depth + 1, _visited, _subgraph
)
def schema(self) -> Dict[str, Any]:
"""Return schema."""
return {
"schema": [
{
"type": "VERTEX",
"label": f"{self._vertex_label}",
"properties": [{"name": k} for k in self._vertex_prop_keys],
},
{
"type": "EDGE",
"label": f"{self._edge_label}",
"properties": [{"name": k} for k in self._edge_prop_keys],
},
]
}
def format(self) -> str:
"""Format graph to string."""
vs_str = "\n".join(v.format(self.vertex_label) for v in self.vertices())
es_str = "\n".join(
f"{self.get_vertex(e.sid).format(self.vertex_label)}"
f"{e.format(self.edge_label)}"
f"{self.get_vertex(e.tid).format(self.vertex_label)}"
for e in self.edges()
)
return f"Vertices:\n{vs_str}\n\nEdges:\n{es_str}"
def graphviz(self, name="g"):
"""View graphviz graph: https://dreampuf.github.io/GraphvizOnline."""
g = nx.MultiDiGraph()
for vertex in self.vertices():
g.add_node(vertex.vid)
for edge in self.edges():
triplet = edge.triplet(self.edge_label)
g.add_edge(triplet[0], triplet[2], label=triplet[1])
digraph = nx.nx_agraph.to_agraph(g).to_string()
digraph = digraph.replace('digraph ""', f"digraph {name}")
digraph = re.sub(r"key=\d+,?\s*", "", digraph)
return digraph
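A small end-to-end sketch of MemoryGraph: edges must carry the edge label property, endpoint vertices are created on demand, and search returns a subgraph:

g = MemoryGraph()  # vertex_label=None, edge_label="label" by default
g.append_edge(Edge("Philz", "coffee shop", label="is"))
g.append_edge(Edge("Philz", "Berkeley", label="founded in"))
sub = g.search(["Philz"], direct=Direction.OUT, depth=1)
print(sub.format())
# Vertices:
# (Philz)
# (coffee shop)
# (Berkeley)
#
# Edges:
# (Philz)-[is]->(coffee shop)
# (Philz)-[founded in]->(Berkeley)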

View File

@ -0,0 +1,81 @@
"""Graph store base class."""
import json
import logging
from typing import List, Optional, Tuple
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph
logger = logging.getLogger(__name__)
class MemoryGraphStoreConfig(GraphStoreConfig):
"""Memory graph store config."""
model_config = ConfigDict(arbitrary_types_allowed=True)
edge_name_key: str = Field(
default="label",
description="The label of edge name, `label` by default.",
)
class MemoryGraphStore(GraphStoreBase):
"""Memory graph store."""
def __init__(self, graph_store_config: MemoryGraphStoreConfig):
"""Initialize MemoryGraphStore with a memory graph."""
self._edge_name_key = graph_store_config.edge_name_key
self._graph = MemoryGraph(edge_label=self._edge_name_key)
def insert_triplet(self, sub: str, rel: str, obj: str):
"""Insert a triplet into the graph."""
self._graph.append_edge(Edge(sub, obj, **{self._edge_name_key: rel}))
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
"""Retrieve triplets originating from a subject."""
subgraph = self.explore([sub], direct=Direction.OUT, depth=1)
return [(e.get_prop(self._edge_name_key), e.tid) for e in subgraph.edges()]
def delete_triplet(self, sub: str, rel: str, obj: str):
"""Delete a specific triplet from the graph."""
self._graph.del_edges(sub, obj, **{self._edge_name_key: rel})
def drop(self):
"""Drop graph."""
self._graph = None
def get_schema(self, refresh: bool = False) -> str:
"""Return the graph schema as a JSON string."""
return json.dumps(self._graph.schema())
def get_full_graph(self, limit: Optional[int] = None) -> MemoryGraph:
"""Return self."""
if not limit:
return self._graph
subgraph = MemoryGraph()
for count, edge in enumerate(self._graph.edges()):
if count >= limit:
break
subgraph.upsert_vertex(self._graph.get_vertex(edge.sid))
subgraph.upsert_vertex(self._graph.get_vertex(edge.tid))
subgraph.append_edge(edge)
return subgraph
def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: Optional[int] = None,
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> MemoryGraph:
"""Explore the graph from given subjects up to a depth."""
return self._graph.search(subs, direct, depth, fan, limit)
def query(self, query: str, **args) -> Graph:
"""Execute a query on graph."""
raise NotImplementedError("Query memory graph not allowed")

View File

@ -0,0 +1,64 @@
"""Neo4j vector store."""
import logging
from typing import List, Optional, Tuple
from dbgpt._private.pydantic import ConfigDict
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import Direction, Graph, MemoryGraph
logger = logging.getLogger(__name__)
class Neo4jStoreConfig(GraphStoreConfig):
"""Neo4j store config."""
model_config = ConfigDict(arbitrary_types_allowed=True)
class Neo4jStore(GraphStoreBase):
"""Neo4j graph store."""
# todo: add neo4j implementation
def __init__(self, graph_store_config: Neo4jStoreConfig):
"""Initialize the Neo4jStore with connection details."""
pass
def insert_triplet(self, sub: str, rel: str, obj: str):
"""Insert triplets."""
pass
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
"""Get triplets."""
return []
def delete_triplet(self, sub: str, rel: str, obj: str):
"""Delete triplets."""
pass
def drop(self):
"""Drop graph."""
pass
def get_schema(self, refresh: bool = False) -> str:
"""Get schema."""
return ""
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
"""Get full graph."""
return MemoryGraph()
def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: Optional[int] = None,
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> Graph:
"""Explore the graph from given subjects up to a depth."""
return MemoryGraph()
def query(self, query: str, **args) -> Graph:
"""Execute a query on graph."""
return MemoryGraph()

View File

@ -0,0 +1,239 @@
"""TuGraph vector store."""
import logging
import os
from typing import List, Optional, Tuple
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.datasource.conn_tugraph import TuGraphConnector
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import Direction, Edge, MemoryGraph, Vertex
logger = logging.getLogger(__name__)
class TuGraphStoreConfig(GraphStoreConfig):
"""TuGraph store config."""
model_config = ConfigDict(arbitrary_types_allowed=True)
host: str = Field(
default="127.0.0.1",
description="TuGraph host",
)
port: int = Field(
default=7687,
description="TuGraph port",
)
username: str = Field(
default="admin",
description="login username",
)
password: str = Field(
default="123456",
description="login password",
)
vertex_type: str = Field(
default="entity",
description="The type of graph vertex, `entity` by default.",
)
edge_type: str = Field(
default="relation",
description="The type of graph edge, `relation` by default.",
)
edge_name_key: str = Field(
default="label",
description="The label of edge name, `label` by default.",
)
class TuGraphStore(GraphStoreBase):
"""TuGraph graph store."""
def __init__(self, config: TuGraphStoreConfig) -> None:
"""Initialize the TuGraphStore with connection details."""
self._host = os.getenv("TUGRAPH_HOST", "127.0.0.1") or config.host
self._port = int(os.getenv("TUGRAPH_PORT", 7687)) or config.port
self._username = os.getenv("TUGRAPH_USERNAME", "admin") or config.username
self._password = os.getenv("TUGRAPH_PASSWORD", "73@TuGraph") or config.password
self._node_label = (
os.getenv("TUGRAPH_VERTEX_TYPE", "entity") or config.vertex_type
)
self._edge_label = (
os.getenv("TUGRAPH_EDGE_TYPE", "relation") or config.edge_type
)
self.edge_name_key = (
os.getenv("TUGRAPH_EDGE_NAME_KEY", "label") or config.edge_name_key
)
self._graph_name = config.name
self.conn = TuGraphConnector.from_uri_db(
host=self._host,
port=self._port,
user=self._username,
pwd=self._password,
db_name=config.name,
)
self.conn.create_graph(graph_name=config.name)
self._create_schema()
def _check_label(self, elem_type: str):
result = self.conn.get_table_names()
if elem_type == "vertex":
return self._node_label in result["vertex_tables"]
if elem_type == "edge":
return self._edge_label in result["edge_tables"]
def _create_schema(self):
if not self._check_label("vertex"):
create_vertex_gql = (
f"CALL db.createLabel("
f"'vertex', '{self._node_label}', "
f"'id', ['id',string,false])"
)
self.conn.run(create_vertex_gql)
if not self._check_label("edge"):
create_edge_gql = f"""CALL db.createLabel(
'edge', '{self._edge_label}', '[["{self._node_label}",
"{self._node_label}"]]', ["id",STRING,false])"""
self.conn.run(create_edge_gql)
def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
"""Get triplets."""
query = (
f"MATCH (n1:{self._node_label})-[r]->(n2:{self._node_label}) "
f'WHERE n1.id = "{subj}" RETURN r.id as rel, n2.id as obj;'
)
data = self.conn.run(query)
return [(record["rel"], record["obj"]) for record in data]
def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
"""Add triplet."""
def escape_quotes(value: str) -> str:
"""Escape single and double quotes in a string for queries."""
return value.replace("'", "\\'").replace('"', '\\"')
subj_escaped = escape_quotes(subj)
rel_escaped = escape_quotes(rel)
obj_escaped = escape_quotes(obj)
subj_query = f"MERGE (n1:{self._node_label} {{id:'{subj_escaped}'}})"
obj_query = f"MERGE (n1:{self._node_label} {{id:'{obj_escaped}'}})"
rel_query = (
f"MERGE (n1:{self._node_label} {{id:'{subj_escaped}'}})"
f"-[r:{self._edge_label} {{id:'{rel_escaped}'}}]->"
f"(n2:{self._node_label} {{id:'{obj_escaped}'}})"
)
self.conn.run(query=subj_query)
self.conn.run(query=obj_query)
self.conn.run(query=rel_query)
def drop(self):
"""Delete Graph."""
self.conn.delete_graph(self._graph_name)
def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
"""Delete triplet."""
del_query = (
f"MATCH (n1:{self._node_label} {{id:'{sub}'}})"
f"-[r:{self._edge_label} {{id:'{rel}'}}]->"
f"(n2:{self._node_label} {{id:'{obj}'}}) DELETE n1,n2,r"
)
self.conn.run(query=del_query)
def get_schema(self, refresh: bool = False) -> str:
"""Get the schema of the graph store."""
query = "CALL dbms.graph.getGraphSchema()"
data = self.conn.run(query=query)
schema = data[0]["schema"]
return schema
def get_full_graph(self, limit: Optional[int] = None) -> MemoryGraph:
"""Get full graph."""
if not limit:
raise Exception("limit must be set")
return self.query(f"MATCH (n)-[r]-(m) RETURN n,m,r LIMIT {limit}")
def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: Optional[int] = None,
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> MemoryGraph:
"""Explore the graph from given subjects up to a depth."""
if fan is not None:
raise ValueError("Fan functionality is not supported at this time.")
depth_string = ".." if depth is None else f"1..{depth}"
limit_string = "" if limit is None else f"LIMIT {limit}"
query = (
f"MATCH p=(n:{self._node_label})"
f"-[r:{self._edge_label}*{depth_string}]-(m:{self._node_label}) "
f"WHERE n.id IN {subs} RETURN p {limit_string}"
)
return self.query(query)
def query(self, query: str, **args) -> MemoryGraph:
"""Execute a query on graph."""
def _format_paths(paths):
formatted_paths = []
for path in paths:
formatted_path = []
nodes = list(path["p"].nodes)
rels = list(path["p"].relationships)
for i in range(len(nodes)):
formatted_path.append(nodes[i]._properties["id"])
if i < len(rels):
formatted_path.append(rels[i]._properties["id"])
formatted_paths.append(formatted_path)
return formatted_paths
def _format_query_data(data):
node_ids_set = set()
rels_set = set()
from neo4j import graph
for record in data:
for key in record.keys():
value = record[key]
if isinstance(value, graph.Node):
node_id = value._properties["id"]
node_ids_set.add(node_id)
elif isinstance(value, graph.Relationship):
rel_nodes = value.nodes
prop_id = value._properties["id"]
src_id = rel_nodes[0]._properties["id"]
dst_id = rel_nodes[1]._properties["id"]
rels_set.add((src_id, dst_id, prop_id))
elif isinstance(value, graph.Path):
formatted_paths = _format_paths(data)
for path in formatted_paths:
for i in range(0, len(path), 2):
node_ids_set.add(path[i])
if i + 2 < len(path):
rels_set.add((path[i], path[i + 2], path[i + 1]))
nodes = [Vertex(node_id) for node_id in node_ids_set]
rels = [
Edge(src_id, dst_id, label=prop_id)
for (src_id, dst_id, prop_id) in rels_set
]
return {"nodes": nodes, "edges": rels}
result = self.conn.run(query=query)
graph = _format_query_data(result)
mg = MemoryGraph()
for vertex in graph["nodes"]:
mg.upsert_vertex(vertex)
for edge in graph["edges"]:
mg.append_edge(edge)
return mg
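A usage sketch for TuGraphStore, assuming a TuGraph server is reachable with the connection values below (all values hypothetical; constructing the store opens the connection and creates the graph and schema):

config = TuGraphStoreConfig(
    name="dbgpt_kg",
    host="127.0.0.1",
    port=7687,
    username="admin",
    password="73@TuGraph",
)
store = TuGraphStore(config)
store.insert_triplet("Philz", "founded in", "Berkeley")
subgraph = store.explore(["Philz"], depth=2, limit=100)
print(subgraph.format())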

View File

@ -0,0 +1 @@
"""Module for KG."""

Some files were not shown because too many files have changed in this diff.