fix: fix unit test error (#2085)

Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Appointat <kuda.czk@antgroup.com>
Florian 2024-10-22 09:35:51 +08:00 committed by GitHub
parent 6d6667812b
commit d9e20426fe
11 changed files with 129 additions and 113 deletions


@@ -1,7 +1,7 @@
"""TuGraph Connector."""
import json
from typing import Dict, Generator, List, Tuple, cast
from typing import Dict, Generator, Iterator, List, cast
from .base import BaseConnector
@@ -20,7 +20,7 @@ class TuGraphConnector(BaseConnector):
self._graph = graph
self._session = None
def create_graph(self, graph_name: str) -> None:
def create_graph(self, graph_name: str) -> bool:
"""Create a new graph in the database if it doesn't already exist."""
try:
with self._driver.session(database="default") as session:
@@ -33,6 +33,8 @@ class TuGraphConnector(BaseConnector):
except Exception as e:
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e
return not exists
def delete_graph(self, graph_name: str) -> None:
"""Delete a graph in the database if it exists."""
with self._driver.session(database="default") as session:
@@ -60,20 +62,18 @@ class TuGraphConnector(BaseConnector):
"`pip install neo4j`"
) from err
def get_table_names(self) -> Tuple[List[str], List[str]]:
def get_table_names(self) -> Iterator[str]:
"""Get all table names from the TuGraph by Neo4j driver."""
with self._driver.session(database=self._graph) as session:
# Run the query to get vertex labels
raw_vertex_labels: Dict[str, str] = session.run(
"CALL db.vertexLabels()"
).data()
raw_vertex_labels = session.run("CALL db.vertexLabels()").data()
vertex_labels = [table_name["label"] for table_name in raw_vertex_labels]
# Run the query to get edge labels
raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data()
raw_edge_labels = session.run("CALL db.edgeLabels()").data()
edge_labels = [table_name["label"] for table_name in raw_edge_labels]
return vertex_labels, edge_labels
return iter(vertex_labels + edge_labels)
def get_grants(self):
"""Get grants."""


@@ -76,8 +76,8 @@ def _parse_db_summary(
table_info_summaries = None
if isinstance(conn, TuGraphConnector):
table_names = conn.get_table_names()
v_tables = table_names.get("vertex_tables", [])
e_tables = table_names.get("edge_tables", [])
v_tables = table_names.get("vertex_tables", []) # type: ignore
e_tables = table_names.get("edge_tables", []) # type: ignore
table_info_summaries = [
_parse_table_summary(conn, summary_template, table_name, "vertex")
for table_name in v_tables


@@ -141,8 +141,8 @@ class TuGraphStore(GraphStoreBase):
if len(missing_plugins):
for name in missing_plugins:
try:
from dbgpt_tugraph_plugins import (
get_plugin_binary_path, # type:ignore[import-untyped]
from dbgpt_tugraph_plugins import ( # type: ignore
get_plugin_binary_path,
)
except ImportError:
logger.error(
@@ -150,7 +150,7 @@ class TuGraphStore(GraphStoreBase):
"pip install dbgpt-tugraph-plugins==0.1.0rc1 -U -i "
"https://pypi.org/simple"
)
plugin_path = get_plugin_binary_path("leiden")
plugin_path = get_plugin_binary_path("leiden") # type: ignore
with open(plugin_path, "rb") as f:
content = f.read()
content = base64.b64encode(content).decode()


@@ -3,7 +3,7 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import AsyncGenerator, Iterator, List, Optional, Union
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union
from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.graph import (
@@ -156,7 +156,11 @@ class GraphStoreAdapter(ABC):
"""Create graph."""
@abstractmethod
def create_graph_label(self) -> None:
def create_graph_label(
self,
graph_elem_type: GraphElemType,
graph_properties: List[Dict[str, Union[str, bool]]],
) -> None:
"""Create a graph label.
The graph label is used to identify and distinguish different types of nodes
@@ -176,7 +180,12 @@ class GraphStoreAdapter(ABC):
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: Optional[int] = None,
depth: int = 3,
fan: Optional[int] = None,
limit: Optional[int] = None,
search_scope: Optional[
Literal["knowledge_graph", "document_graph"]
] = "knowledge_graph",
) -> MemoryGraph:
"""Explore the graph from given subjects up to a depth."""


@@ -2,7 +2,7 @@
import json
import logging
from typing import AsyncGenerator, Iterator, List, Optional, Tuple, Union
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Tuple, Union
from dbgpt.storage.graph_store.graph import (
Direction,
@@ -173,6 +173,8 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
def create_graph_label(
self,
graph_elem_type: GraphElemType,
graph_properties: List[Dict[str, Union[str, bool]]],
) -> None:
"""Create a graph label.
@@ -201,9 +203,12 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: int | None = None,
fan: int | None = None,
limit: int | None = None,
depth: int = 3,
fan: Optional[int] = None,
limit: Optional[int] = None,
search_scope: Optional[
Literal["knowledge_graph", "document_graph"]
] = "knowledge_graph",
) -> MemoryGraph:
"""Explore the graph from given subjects up to a depth."""
return self._graph_store._graph.search(subs, direct, depth, fan, limit)


@@ -79,7 +79,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
@property
def graph_store(self) -> TuGraphStore:
"""Get the graph store."""
return self._graph_store
return self._graph_store # type: ignore[return-value]
def get_graph_config(self):
"""Get the graph store config."""
@@ -176,29 +176,23 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
[{self._convert_dict_to_str(edge_list)}])"""
self.graph_store.conn.run(query=relation_query)
def upsert_chunks(
self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
) -> None:
def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None:
"""Upsert chunks."""
chunks_list = list(chunks)
if chunks_list and isinstance(chunks_list[0], ParagraphChunk):
chunk_list = [
{
"id": self._escape_quotes(chunk.chunk_id),
"name": self._escape_quotes(chunk.chunk_name),
"content": self._escape_quotes(chunk.content),
}
for chunk in chunks_list
]
else:
chunk_list = [
{
"id": self._escape_quotes(chunk.vid),
"name": self._escape_quotes(chunk.name),
"content": self._escape_quotes(chunk.get_prop("content")),
}
for chunk in chunks_list
]
chunk_list = [
{
"id": self._escape_quotes(chunk.chunk_id),
"name": self._escape_quotes(chunk.chunk_name),
"content": self._escape_quotes(chunk.content),
}
if isinstance(chunk, ParagraphChunk)
else {
"id": self._escape_quotes(chunk.vid),
"name": self._escape_quotes(chunk.name),
"content": self._escape_quotes(chunk.get_prop("content")),
}
for chunk in chunks
]
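Note (reviewer sketch): the two near-duplicate branches collapse into one comprehension whose conditional expression dispatches per element, so a mixed Iterator[Union[Vertex, ParagraphChunk]] is mapped in a single pass without materializing it first. The pattern in isolation:

    # Hedged sketch of the dispatch-in-comprehension pattern used above;
    # `items` stands in for any Iterator[Union[Vertex, ParagraphChunk]].
    rows = [
        {"id": item.chunk_id, "name": item.chunk_name}
        if isinstance(item, ParagraphChunk)
        else {"id": item.vid, "name": item.name}
        for item in items
    ]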
chunk_query = (
f"CALL db.upsertVertex("
f'"{GraphElemType.CHUNK.value}", '
@@ -207,28 +201,24 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
self.graph_store.conn.run(query=chunk_query)
def upsert_documents(
self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
self, documents: Iterator[Union[Vertex, ParagraphChunk]]
) -> None:
"""Upsert documents."""
documents_list = list(documents)
if documents_list and isinstance(documents_list[0], ParagraphChunk):
document_list = [
{
"id": self._escape_quotes(document.chunk_id),
"name": self._escape_quotes(document.chunk_name),
"content": "",
}
for document in documents_list
]
else:
document_list = [
{
"id": self._escape_quotes(document.vid),
"name": self._escape_quotes(document.name),
"content": self._escape_quotes(document.get_prop("content")) or "",
}
for document in documents_list
]
document_list = [
{
"id": self._escape_quotes(document.chunk_id),
"name": self._escape_quotes(document.chunk_name),
"content": "",
}
if isinstance(document, ParagraphChunk)
else {
"id": self._escape_quotes(document.vid),
"name": self._escape_quotes(document.name),
"content": "",
}
for document in documents
]
document_query = (
"CALL db.upsertVertex("
f'"{GraphElemType.DOCUMENT.value}", '
@@ -258,7 +248,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
self.graph_store.conn.run(query=vertex_query)
self.graph_store.conn.run(query=edge_query)
def upsert_graph(self, graph: MemoryGraph) -> None:
def upsert_graph(self, graph: Graph) -> None:
"""Add graph to the graph store.
Args:
@@ -362,7 +352,8 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
def create_graph(self, graph_name: str):
"""Create a graph."""
self.graph_store.conn.create_graph(graph_name=graph_name)
if not self.graph_store.conn.create_graph(graph_name=graph_name):
return
# Create the graph schema
def _format_graph_propertity_schema(
@@ -474,12 +465,14 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
(vertices) and edges in the graph.
"""
if graph_elem_type.is_vertex(): # vertex
data = json.dumps({
"label": graph_elem_type.value,
"type": "VERTEX",
"primary": "id",
"properties": graph_properties,
})
data = json.dumps(
{
"label": graph_elem_type.value,
"type": "VERTEX",
"primary": "id",
"properties": graph_properties,
}
)
gql = f"""CALL db.createVertexLabelByJson('{data}')"""
else: # edge
@@ -505,12 +498,14 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
else:
raise ValueError("Invalid graph element type.")
data = json.dumps({
"label": graph_elem_type.value,
"type": "EDGE",
"constraints": edge_direction(graph_elem_type),
"properties": graph_properties,
})
data = json.dumps(
{
"label": graph_elem_type.value,
"type": "EDGE",
"constraints": edge_direction(graph_elem_type),
"properties": graph_properties,
}
)
gql = f"""CALL db.createEdgeLabelByJson('{data}')"""
self.graph_store.conn.run(gql)
@@ -530,18 +525,16 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
True if the label exists in the specified graph element type, otherwise
False.
"""
vertex_tables, edge_tables = self.graph_store.conn.get_table_names()
tables = self.graph_store.conn.get_table_names()
if graph_elem_type.is_vertex():
return graph_elem_type in vertex_tables
else:
return graph_elem_type in edge_tables
return graph_elem_type.value in tables
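Note (reviewer sketch): get_table_names() now returns an iterator, and the `in` test consumes it. That is fine for a single membership check, but any repeated lookup should materialize the labels first:

    # Hedged sketch: materialize once if the labels are tested more than once.
    tables = set(self.graph_store.conn.get_table_names())
    exists = graph_elem_type.value in tables  # safe to evaluate repeatedly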
def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: int = 3,
fan: Optional[int] = None,
limit: Optional[int] = None,
search_scope: Optional[
Literal["knowledge_graph", "document_graph"]
@@ -621,11 +614,17 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
mg.append_edge(edge)
return mg
async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
# type: ignore[override]
# mypy: ignore-errors
async def stream_query( # type: ignore[override]
self,
query: str,
**kwargs,
) -> AsyncGenerator[Graph, None]:
"""Execute a stream query."""
from neo4j import graph
async for record in self.graph_store.conn.run_stream(query):
async for record in self.graph_store.conn.run_stream(query): # type: ignore
mg = MemoryGraph()
for key in record.keys():
value = record[key]
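Note (reviewer sketch): stream_query() is an async generator, so callers consume it with `async for` inside a coroutine. A hedged usage sketch, assuming an adapter instance and a running event loop:

    # Hedged sketch; `adapter` is an assumed TuGraphStoreAdapter.
    async def dump_graphs(adapter):
        async for g in adapter.stream_query("MATCH p=(n)-[r]->(m) RETURN p LIMIT 3"):
            print(g.vertex_count, g.edge_count)  # MemoryGraph counters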
@@ -650,15 +649,19 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
rels = list(record["p"].relationships)
formatted_path = []
for i in range(len(nodes)):
formatted_path.append({
"id": nodes[i]._properties["id"],
"description": nodes[i]._properties["description"],
})
formatted_path.append(
{
"id": nodes[i]._properties["id"],
"description": nodes[i]._properties["description"],
}
)
if i < len(rels):
formatted_path.append({
"id": rels[i]._properties["id"],
"description": rels[i]._properties["description"],
})
formatted_path.append(
{
"id": rels[i]._properties["id"],
"description": rels[i]._properties["description"],
}
)
for i in range(0, len(formatted_path), 2):
mg.upsert_vertex(
Vertex(


@@ -149,25 +149,25 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
return [chunk.chunk_id for chunk in chunks]
async def _aload_document_graph(self, chunks: List[Chunk]) -> List[str]:
async def _aload_document_graph(self, chunks: List[Chunk]) -> None:
"""Load the knowledge graph from the chunks.
The chunks include the doc structure.
"""
if not self._graph_store.get_config().document_graph_enabled:
return []
return
chunks: List[ParagraphChunk] = [
_chunks: List[ParagraphChunk] = [
ParagraphChunk.model_validate(chunk.model_dump()) for chunk in chunks
]
documment_chunk, chunks = self._load_chunks(chunks)
documment_chunk, paragraph_chunks = self._load_chunks(_chunks)
# upsert the document and chunks vertices
self._graph_store_apdater.upsert_documents(iter([documment_chunk]))
self._graph_store_apdater.upsert_chunks(iter(chunks))
self._graph_store_apdater.upsert_chunks(iter(paragraph_chunks))
# upsert the document structure
for chunk_index, chunk in enumerate(chunks):
for chunk_index, chunk in enumerate(paragraph_chunks):
# document -> include -> chunk
if chunk.parent_is_document:
self._graph_store_apdater.upsert_doc_include_chunk(chunk=chunk)
@@ -177,7 +177,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
# chunk -> next -> chunk
if chunk_index >= 1:
self._graph_store_apdater.upsert_chunk_next_chunk(
chunk=chunks[chunk_index - 1], next_chunk=chunk
chunk=paragraph_chunks[chunk_index - 1], next_chunk=chunk
)
async def _aload_triplet_graph(self, chunks: List[Chunk]) -> None:
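Note (reviewer sketch): with the rename to `paragraph_chunks`, the document-structure pass reads as: upsert the document vertex and the chunk vertices, then link document -> include -> chunk for top-level chunks and chunk -> next -> chunk for ordering. A simplified restatement (elided branches omitted, names as in the diff):

    # Hedged restatement of the linking pass above.
    self._graph_store_apdater.upsert_documents(iter([documment_chunk]))
    self._graph_store_apdater.upsert_chunks(iter(paragraph_chunks))
    for i, chunk in enumerate(paragraph_chunks):
        if chunk.parent_is_document:
            # document -> include -> chunk
            self._graph_store_apdater.upsert_doc_include_chunk(chunk=chunk)
        if i >= 1:
            # chunk -> next -> chunk
            self._graph_store_apdater.upsert_chunk_next_chunk(
                chunk=paragraph_chunks[i - 1], next_chunk=chunk
            )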
@@ -280,7 +280,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
Return:
List[Chunk]: The similar documents.
"""
pass
return []
async def asimilar_search_with_scores(
self,
@@ -301,9 +301,6 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
keywords: List[str] = await self._keyword_extractor.extract(text)
# Local search: extract keywords and explore subgraph
subgraph = MemoryGraph()
subgraph_for_doc = MemoryGraph()
triplet_graph_enabled = self._graph_store.get_config().triplet_graph_enabled
document_graph_enabled = self._graph_store.get_config().document_graph_enabled
@@ -329,9 +326,10 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
limit=self._config.knowledge_graph_chunk_search_top_size,
search_scope="document_graph",
)
knowledge_graph_str = subgraph.format()
knowledge_graph_for_doc_str = subgraph_for_doc.format()
knowledge_graph_str = subgraph.format() if subgraph else ""
knowledge_graph_for_doc_str = (
subgraph_for_doc.format() if subgraph_for_doc else ""
)
logger.info(f"Search subgraph from the following keywords:\n{len(keywords)}")


@@ -183,5 +183,5 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
def delete_by_ids(self, ids: str) -> List[str]:
"""Delete by ids."""
self._graph_store_apdater.delete_document(chunk_ids=ids)
self._graph_store_apdater.delete_document(chunk_id=ids)
return []


@@ -88,7 +88,8 @@ def __create_community_kg_connector():
async def ask_chunk(chunk: Chunk, question) -> str:
rag_template = (
"Based on the following [Context] {context}, " "answer [Question] {question}."
"Based on the following [Context] {context}, "
"answer [Question] {question}."
)
template = HumanPromptTemplate.from_template(rag_template)
messages = template.format_messages(context=chunk.content, question=question)


@@ -46,9 +46,9 @@ def test_insert_and_get_triplets(tugraph_store_adapter: TuGraphStoreAdapter):
assert len(triplets) == 1
def test_query(store: TuGraphStore):
def test_query(tugraph_store_adapter: TuGraphStoreAdapter):
query = "MATCH (n)-[r]->(n1) return n,n1,r limit 3"
result = store.query(query)
result = tugraph_store_adapter.query(query)
v_c = result.vertex_count
e_c = result.edge_count
assert v_c == 3 and e_c == 3


@@ -25,12 +25,12 @@ def g():
(lambda g: g.del_vertices("G", "G"), 6, 9),
(lambda g: g.del_vertices("C"), 6, 7),
(lambda g: g.del_vertices("A", "G"), 5, 6),
(lambda g: g.del_edges("A", "A"), 7, 7),
(lambda g: g.del_edges("A", "B"), 7, 8),
(lambda g: g.del_edges("A", "A", None), 7, 7),
(lambda g: g.del_edges("A", "B", None), 7, 8),
(lambda g: g.del_edges("A", "A", "0"), 7, 8),
(lambda g: g.del_edges("E", "F", "8"), 7, 8),
(lambda g: g.del_edges("E", "F", "9"), 7, 9),
(lambda g: g.del_edges("E", "F", val=1), 7, 9),
(lambda g: g.del_edges("E", "F", None, val=1), 7, 9),
(lambda g: g.del_edges("E", "F", "8", val=1), 7, 9),
(lambda g: g.del_edges("E", "F", "9", val=1), 7, 9),
(lambda g: g.del_neighbor_edges("A", Direction.IN), 7, 7),
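Note (reviewer sketch): the updated parametrization implies del_edges() now takes the edge name as an explicit third positional argument ahead of keyword property filters; this reading is inferred from the test usage only:

    # Hedged reading of the new cases (signature inferred from the tests).
    g.del_edges("A", "B", None)         # every A -> B edge, regardless of name
    g.del_edges("A", "A", "0")          # only the A -> A edge named "0"
    g.del_edges("E", "F", None, val=1)  # E -> F edges whose property val == 1
    g.del_edges("E", "F", "9", val=1)   # must match both name "9" and val == 1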