mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-21 03:19:19 +00:00
fix: fix unit test error (#2085)
Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: Appointat <kuda.czk@antgroup.com>
This commit is contained in:
parent
6d6667812b
commit
d9e20426fe
@ -1,7 +1,7 @@
|
||||
"""TuGraph Connector."""
|
||||
|
||||
import json
|
||||
from typing import Dict, Generator, List, Tuple, cast
|
||||
from typing import Dict, Generator, Iterator, List, cast
|
||||
|
||||
from .base import BaseConnector
|
||||
|
||||
@ -20,7 +20,7 @@ class TuGraphConnector(BaseConnector):
|
||||
self._graph = graph
|
||||
self._session = None
|
||||
|
||||
def create_graph(self, graph_name: str) -> None:
|
||||
def create_graph(self, graph_name: str) -> bool:
|
||||
"""Create a new graph in the database if it doesn't already exist."""
|
||||
try:
|
||||
with self._driver.session(database="default") as session:
|
||||
@ -33,6 +33,8 @@ class TuGraphConnector(BaseConnector):
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e
|
||||
|
||||
return not exists
|
||||
|
||||
def delete_graph(self, graph_name: str) -> None:
|
||||
"""Delete a graph in the database if it exists."""
|
||||
with self._driver.session(database="default") as session:
|
||||
@ -60,20 +62,18 @@ class TuGraphConnector(BaseConnector):
|
||||
"`pip install neo4j`"
|
||||
) from err
|
||||
|
||||
def get_table_names(self) -> Tuple[List[str], List[str]]:
|
||||
def get_table_names(self) -> Iterator[str]:
|
||||
"""Get all table names from the TuGraph by Neo4j driver."""
|
||||
with self._driver.session(database=self._graph) as session:
|
||||
# Run the query to get vertex labels
|
||||
raw_vertex_labels: Dict[str, str] = session.run(
|
||||
"CALL db.vertexLabels()"
|
||||
).data()
|
||||
raw_vertex_labels = session.run("CALL db.vertexLabels()").data()
|
||||
vertex_labels = [table_name["label"] for table_name in raw_vertex_labels]
|
||||
|
||||
# Run the query to get edge labels
|
||||
raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data()
|
||||
raw_edge_labels = session.run("CALL db.edgeLabels()").data()
|
||||
edge_labels = [table_name["label"] for table_name in raw_edge_labels]
|
||||
|
||||
return vertex_labels, edge_labels
|
||||
return iter(vertex_labels + edge_labels)
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grants."""
|
||||
|
@ -76,8 +76,8 @@ def _parse_db_summary(
|
||||
table_info_summaries = None
|
||||
if isinstance(conn, TuGraphConnector):
|
||||
table_names = conn.get_table_names()
|
||||
v_tables = table_names.get("vertex_tables", [])
|
||||
e_tables = table_names.get("edge_tables", [])
|
||||
v_tables = table_names.get("vertex_tables", []) # type: ignore
|
||||
e_tables = table_names.get("edge_tables", []) # type: ignore
|
||||
table_info_summaries = [
|
||||
_parse_table_summary(conn, summary_template, table_name, "vertex")
|
||||
for table_name in v_tables
|
||||
|
@ -141,8 +141,8 @@ class TuGraphStore(GraphStoreBase):
|
||||
if len(missing_plugins):
|
||||
for name in missing_plugins:
|
||||
try:
|
||||
from dbgpt_tugraph_plugins import (
|
||||
get_plugin_binary_path, # type:ignore[import-untyped]
|
||||
from dbgpt_tugraph_plugins import ( # type: ignore
|
||||
get_plugin_binary_path,
|
||||
)
|
||||
except ImportError:
|
||||
logger.error(
|
||||
@ -150,7 +150,7 @@ class TuGraphStore(GraphStoreBase):
|
||||
"pip install dbgpt-tugraph-plugins==0.1.0rc1 -U -i "
|
||||
"https://pypi.org/simple"
|
||||
)
|
||||
plugin_path = get_plugin_binary_path("leiden")
|
||||
plugin_path = get_plugin_binary_path("leiden") # type: ignore
|
||||
with open(plugin_path, "rb") as f:
|
||||
content = f.read()
|
||||
content = base64.b64encode(content).decode()
|
||||
|
@ -3,7 +3,7 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator, Iterator, List, Optional, Union
|
||||
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union
|
||||
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase
|
||||
from dbgpt.storage.graph_store.graph import (
|
||||
@ -156,7 +156,11 @@ class GraphStoreAdapter(ABC):
|
||||
"""Create graph."""
|
||||
|
||||
@abstractmethod
|
||||
def create_graph_label(self) -> None:
|
||||
def create_graph_label(
|
||||
self,
|
||||
graph_elem_type: GraphElemType,
|
||||
graph_properties: List[Dict[str, Union[str, bool]]],
|
||||
) -> None:
|
||||
"""Create a graph label.
|
||||
|
||||
The graph label is used to identify and distinguish different types of nodes
|
||||
@ -176,7 +180,12 @@ class GraphStoreAdapter(ABC):
|
||||
self,
|
||||
subs: List[str],
|
||||
direct: Direction = Direction.BOTH,
|
||||
depth: Optional[int] = None,
|
||||
depth: int = 3,
|
||||
fan: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
search_scope: Optional[
|
||||
Literal["knowledge_graph", "document_graph"]
|
||||
] = "knowledge_graph",
|
||||
) -> MemoryGraph:
|
||||
"""Explore the graph from given subjects up to a depth."""
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncGenerator, Iterator, List, Optional, Tuple, Union
|
||||
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from dbgpt.storage.graph_store.graph import (
|
||||
Direction,
|
||||
@ -173,6 +173,8 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
def create_graph_label(
|
||||
self,
|
||||
graph_elem_type: GraphElemType,
|
||||
graph_properties: List[Dict[str, Union[str, bool]]],
|
||||
) -> None:
|
||||
"""Create a graph label.
|
||||
|
||||
@ -201,9 +203,12 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
|
||||
self,
|
||||
subs: List[str],
|
||||
direct: Direction = Direction.BOTH,
|
||||
depth: int | None = None,
|
||||
fan: int | None = None,
|
||||
limit: int | None = None,
|
||||
depth: int = 3,
|
||||
fan: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
search_scope: Optional[
|
||||
Literal["knowledge_graph", "document_graph"]
|
||||
] = "knowledge_graph",
|
||||
) -> MemoryGraph:
|
||||
"""Explore the graph from given subjects up to a depth."""
|
||||
return self._graph_store._graph.search(subs, direct, depth, fan, limit)
|
||||
|
@ -79,7 +79,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
@property
|
||||
def graph_store(self) -> TuGraphStore:
|
||||
"""Get the graph store."""
|
||||
return self._graph_store
|
||||
return self._graph_store # type: ignore[return-value]
|
||||
|
||||
def get_graph_config(self):
|
||||
"""Get the graph store config."""
|
||||
@ -176,29 +176,23 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
[{self._convert_dict_to_str(edge_list)}])"""
|
||||
self.graph_store.conn.run(query=relation_query)
|
||||
|
||||
def upsert_chunks(
|
||||
self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
|
||||
) -> None:
|
||||
def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None:
|
||||
"""Upsert chunks."""
|
||||
chunks_list = list(chunks)
|
||||
if chunks_list and isinstance(chunks_list[0], ParagraphChunk):
|
||||
chunk_list = [
|
||||
{
|
||||
"id": self._escape_quotes(chunk.chunk_id),
|
||||
"name": self._escape_quotes(chunk.chunk_name),
|
||||
"content": self._escape_quotes(chunk.content),
|
||||
}
|
||||
for chunk in chunks_list
|
||||
]
|
||||
else:
|
||||
chunk_list = [
|
||||
{
|
||||
"id": self._escape_quotes(chunk.vid),
|
||||
"name": self._escape_quotes(chunk.name),
|
||||
"content": self._escape_quotes(chunk.get_prop("content")),
|
||||
}
|
||||
for chunk in chunks_list
|
||||
]
|
||||
chunk_list = [
|
||||
{
|
||||
"id": self._escape_quotes(chunk.chunk_id),
|
||||
"name": self._escape_quotes(chunk.chunk_name),
|
||||
"content": self._escape_quotes(chunk.content),
|
||||
}
|
||||
if isinstance(chunk, ParagraphChunk)
|
||||
else {
|
||||
"id": self._escape_quotes(chunk.vid),
|
||||
"name": self._escape_quotes(chunk.name),
|
||||
"content": self._escape_quotes(chunk.get_prop("content")),
|
||||
}
|
||||
for chunk in chunks
|
||||
]
|
||||
|
||||
chunk_query = (
|
||||
f"CALL db.upsertVertex("
|
||||
f'"{GraphElemType.CHUNK.value}", '
|
||||
@ -207,28 +201,24 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
self.graph_store.conn.run(query=chunk_query)
|
||||
|
||||
def upsert_documents(
|
||||
self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
|
||||
self, documents: Iterator[Union[Vertex, ParagraphChunk]]
|
||||
) -> None:
|
||||
"""Upsert documents."""
|
||||
documents_list = list(documents)
|
||||
if documents_list and isinstance(documents_list[0], ParagraphChunk):
|
||||
document_list = [
|
||||
{
|
||||
"id": self._escape_quotes(document.chunk_id),
|
||||
"name": self._escape_quotes(document.chunk_name),
|
||||
"content": "",
|
||||
}
|
||||
for document in documents_list
|
||||
]
|
||||
else:
|
||||
document_list = [
|
||||
{
|
||||
"id": self._escape_quotes(document.vid),
|
||||
"name": self._escape_quotes(document.name),
|
||||
"content": self._escape_quotes(document.get_prop("content")) or "",
|
||||
}
|
||||
for document in documents_list
|
||||
]
|
||||
document_list = [
|
||||
{
|
||||
"id": self._escape_quotes(document.chunk_id),
|
||||
"name": self._escape_quotes(document.chunk_name),
|
||||
"content": "",
|
||||
}
|
||||
if isinstance(document, ParagraphChunk)
|
||||
else {
|
||||
"id": self._escape_quotes(document.vid),
|
||||
"name": self._escape_quotes(document.name),
|
||||
"content": "",
|
||||
}
|
||||
for document in documents
|
||||
]
|
||||
|
||||
document_query = (
|
||||
"CALL db.upsertVertex("
|
||||
f'"{GraphElemType.DOCUMENT.value}", '
|
||||
@ -258,7 +248,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
self.graph_store.conn.run(query=vertex_query)
|
||||
self.graph_store.conn.run(query=edge_query)
|
||||
|
||||
def upsert_graph(self, graph: MemoryGraph) -> None:
|
||||
def upsert_graph(self, graph: Graph) -> None:
|
||||
"""Add graph to the graph store.
|
||||
|
||||
Args:
|
||||
@ -362,7 +352,8 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
def create_graph(self, graph_name: str):
|
||||
"""Create a graph."""
|
||||
self.graph_store.conn.create_graph(graph_name=graph_name)
|
||||
if not self.graph_store.conn.create_graph(graph_name=graph_name):
|
||||
return
|
||||
|
||||
# Create the graph schema
|
||||
def _format_graph_propertity_schema(
|
||||
@ -474,12 +465,14 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
(vertices) and edges in the graph.
|
||||
"""
|
||||
if graph_elem_type.is_vertex(): # vertex
|
||||
data = json.dumps({
|
||||
"label": graph_elem_type.value,
|
||||
"type": "VERTEX",
|
||||
"primary": "id",
|
||||
"properties": graph_properties,
|
||||
})
|
||||
data = json.dumps(
|
||||
{
|
||||
"label": graph_elem_type.value,
|
||||
"type": "VERTEX",
|
||||
"primary": "id",
|
||||
"properties": graph_properties,
|
||||
}
|
||||
)
|
||||
gql = f"""CALL db.createVertexLabelByJson('{data}')"""
|
||||
else: # edge
|
||||
|
||||
@ -505,12 +498,14 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
else:
|
||||
raise ValueError("Invalid graph element type.")
|
||||
|
||||
data = json.dumps({
|
||||
"label": graph_elem_type.value,
|
||||
"type": "EDGE",
|
||||
"constraints": edge_direction(graph_elem_type),
|
||||
"properties": graph_properties,
|
||||
})
|
||||
data = json.dumps(
|
||||
{
|
||||
"label": graph_elem_type.value,
|
||||
"type": "EDGE",
|
||||
"constraints": edge_direction(graph_elem_type),
|
||||
"properties": graph_properties,
|
||||
}
|
||||
)
|
||||
gql = f"""CALL db.createEdgeLabelByJson('{data}')"""
|
||||
|
||||
self.graph_store.conn.run(gql)
|
||||
@ -530,18 +525,16 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
True if the label exists in the specified graph element type, otherwise
|
||||
False.
|
||||
"""
|
||||
vertex_tables, edge_tables = self.graph_store.conn.get_table_names()
|
||||
tables = self.graph_store.conn.get_table_names()
|
||||
|
||||
if graph_elem_type.is_vertex():
|
||||
return graph_elem_type in vertex_tables
|
||||
else:
|
||||
return graph_elem_type in edge_tables
|
||||
return graph_elem_type.value in tables
|
||||
|
||||
def explore(
|
||||
self,
|
||||
subs: List[str],
|
||||
direct: Direction = Direction.BOTH,
|
||||
depth: int = 3,
|
||||
fan: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
search_scope: Optional[
|
||||
Literal["knowledge_graph", "document_graph"]
|
||||
@ -621,11 +614,17 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
mg.append_edge(edge)
|
||||
return mg
|
||||
|
||||
async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
|
||||
# type: ignore[override]
|
||||
# mypy: ignore-errors
|
||||
async def stream_query( # type: ignore[override]
|
||||
self,
|
||||
query: str,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[Graph, None]:
|
||||
"""Execute a stream query."""
|
||||
from neo4j import graph
|
||||
|
||||
async for record in self.graph_store.conn.run_stream(query):
|
||||
async for record in self.graph_store.conn.run_stream(query): # type: ignore
|
||||
mg = MemoryGraph()
|
||||
for key in record.keys():
|
||||
value = record[key]
|
||||
@ -650,15 +649,19 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
rels = list(record["p"].relationships)
|
||||
formatted_path = []
|
||||
for i in range(len(nodes)):
|
||||
formatted_path.append({
|
||||
"id": nodes[i]._properties["id"],
|
||||
"description": nodes[i]._properties["description"],
|
||||
})
|
||||
formatted_path.append(
|
||||
{
|
||||
"id": nodes[i]._properties["id"],
|
||||
"description": nodes[i]._properties["description"],
|
||||
}
|
||||
)
|
||||
if i < len(rels):
|
||||
formatted_path.append({
|
||||
"id": rels[i]._properties["id"],
|
||||
"description": rels[i]._properties["description"],
|
||||
})
|
||||
formatted_path.append(
|
||||
{
|
||||
"id": rels[i]._properties["id"],
|
||||
"description": rels[i]._properties["description"],
|
||||
}
|
||||
)
|
||||
for i in range(0, len(formatted_path), 2):
|
||||
mg.upsert_vertex(
|
||||
Vertex(
|
||||
|
@ -149,25 +149,25 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
|
||||
return [chunk.chunk_id for chunk in chunks]
|
||||
|
||||
async def _aload_document_graph(self, chunks: List[Chunk]) -> List[str]:
|
||||
async def _aload_document_graph(self, chunks: List[Chunk]) -> None:
|
||||
"""Load the knowledge graph from the chunks.
|
||||
|
||||
The chunks include the doc structure.
|
||||
"""
|
||||
if not self._graph_store.get_config().document_graph_enabled:
|
||||
return []
|
||||
return
|
||||
|
||||
chunks: List[ParagraphChunk] = [
|
||||
_chunks: List[ParagraphChunk] = [
|
||||
ParagraphChunk.model_validate(chunk.model_dump()) for chunk in chunks
|
||||
]
|
||||
documment_chunk, chunks = self._load_chunks(chunks)
|
||||
documment_chunk, paragraph_chunks = self._load_chunks(_chunks)
|
||||
|
||||
# upsert the document and chunks vertices
|
||||
self._graph_store_apdater.upsert_documents(iter([documment_chunk]))
|
||||
self._graph_store_apdater.upsert_chunks(iter(chunks))
|
||||
self._graph_store_apdater.upsert_chunks(iter(paragraph_chunks))
|
||||
|
||||
# upsert the document structure
|
||||
for chunk_index, chunk in enumerate(chunks):
|
||||
for chunk_index, chunk in enumerate(paragraph_chunks):
|
||||
# document -> include -> chunk
|
||||
if chunk.parent_is_document:
|
||||
self._graph_store_apdater.upsert_doc_include_chunk(chunk=chunk)
|
||||
@ -177,7 +177,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
# chunk -> next -> chunk
|
||||
if chunk_index >= 1:
|
||||
self._graph_store_apdater.upsert_chunk_next_chunk(
|
||||
chunk=chunks[chunk_index - 1], next_chunk=chunk
|
||||
chunk=paragraph_chunks[chunk_index - 1], next_chunk=chunk
|
||||
)
|
||||
|
||||
async def _aload_triplet_graph(self, chunks: List[Chunk]) -> None:
|
||||
@ -280,7 +280,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
Return:
|
||||
List[Chunk]: The similar documents.
|
||||
"""
|
||||
pass
|
||||
return []
|
||||
|
||||
async def asimilar_search_with_scores(
|
||||
self,
|
||||
@ -301,9 +301,6 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
keywords: List[str] = await self._keyword_extractor.extract(text)
|
||||
|
||||
# Local search: extract keywords and explore subgraph
|
||||
subgraph = MemoryGraph()
|
||||
subgraph_for_doc = MemoryGraph()
|
||||
|
||||
triplet_graph_enabled = self._graph_store.get_config().triplet_graph_enabled
|
||||
document_graph_enabled = self._graph_store.get_config().document_graph_enabled
|
||||
|
||||
@ -329,9 +326,10 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
limit=self._config.knowledge_graph_chunk_search_top_size,
|
||||
search_scope="document_graph",
|
||||
)
|
||||
|
||||
knowledge_graph_str = subgraph.format()
|
||||
knowledge_graph_for_doc_str = subgraph_for_doc.format()
|
||||
knowledge_graph_str = subgraph.format() if subgraph else ""
|
||||
knowledge_graph_for_doc_str = (
|
||||
subgraph_for_doc.format() if subgraph_for_doc else ""
|
||||
)
|
||||
|
||||
logger.info(f"Search subgraph from the following keywords:\n{len(keywords)}")
|
||||
|
||||
|
@ -183,5 +183,5 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
|
||||
def delete_by_ids(self, ids: str) -> List[str]:
|
||||
"""Delete by ids."""
|
||||
self._graph_store_apdater.delete_document(chunk_ids=ids)
|
||||
self._graph_store_apdater.delete_document(chunk_id=ids)
|
||||
return []
|
||||
|
@ -88,7 +88,8 @@ def __create_community_kg_connector():
|
||||
|
||||
async def ask_chunk(chunk: Chunk, question) -> str:
|
||||
rag_template = (
|
||||
"Based on the following [Context] {context}, " "answer [Question] {question}."
|
||||
"Based on the following [Context] {context}, "
|
||||
"answer [Question] {question}."
|
||||
)
|
||||
template = HumanPromptTemplate.from_template(rag_template)
|
||||
messages = template.format_messages(context=chunk.content, question=question)
|
||||
|
@ -46,9 +46,9 @@ def test_insert_and_get_triplets(tugraph_store_adapter: TuGraphStoreAdapter):
|
||||
assert len(triplets) == 1
|
||||
|
||||
|
||||
def test_query(store: TuGraphStore):
|
||||
def test_query(tugraph_store_adapter: TuGraphStoreAdapter):
|
||||
query = "MATCH (n)-[r]->(n1) return n,n1,r limit 3"
|
||||
result = store.query(query)
|
||||
result = tugraph_store_adapter.query(query)
|
||||
v_c = result.vertex_count
|
||||
e_c = result.edge_count
|
||||
assert v_c == 3 and e_c == 3
|
||||
|
@ -25,12 +25,12 @@ def g():
|
||||
(lambda g: g.del_vertices("G", "G"), 6, 9),
|
||||
(lambda g: g.del_vertices("C"), 6, 7),
|
||||
(lambda g: g.del_vertices("A", "G"), 5, 6),
|
||||
(lambda g: g.del_edges("A", "A"), 7, 7),
|
||||
(lambda g: g.del_edges("A", "B"), 7, 8),
|
||||
(lambda g: g.del_edges("A", "A", None), 7, 7),
|
||||
(lambda g: g.del_edges("A", "B", None), 7, 8),
|
||||
(lambda g: g.del_edges("A", "A", "0"), 7, 8),
|
||||
(lambda g: g.del_edges("E", "F", "8"), 7, 8),
|
||||
(lambda g: g.del_edges("E", "F", "9"), 7, 9),
|
||||
(lambda g: g.del_edges("E", "F", val=1), 7, 9),
|
||||
(lambda g: g.del_edges("E", "F", None, val=1), 7, 9),
|
||||
(lambda g: g.del_edges("E", "F", "8", val=1), 7, 9),
|
||||
(lambda g: g.del_edges("E", "F", "9", val=1), 7, 9),
|
||||
(lambda g: g.del_neighbor_edges("A", Direction.IN), 7, 7),
|
||||
|
Loading…
Reference in New Issue
Block a user