fix: fix unit test error (#2085)

Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Appointat <kuda.czk@antgroup.com>
This commit is contained in:
Florian
2024-10-22 09:35:51 +08:00
committed by GitHub
parent 6d6667812b
commit d9e20426fe
11 changed files with 129 additions and 113 deletions

View File

@@ -3,7 +3,7 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import AsyncGenerator, Iterator, List, Optional, Union
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union
from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.graph import (
@@ -156,7 +156,11 @@ class GraphStoreAdapter(ABC):
"""Create graph."""
@abstractmethod
def create_graph_label(self) -> None:
def create_graph_label(
self,
graph_elem_type: GraphElemType,
graph_properties: List[Dict[str, Union[str, bool]]],
) -> None:
"""Create a graph label.
The graph label is used to identify and distinguish different types of nodes
@@ -176,7 +180,12 @@ class GraphStoreAdapter(ABC):
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: Optional[int] = None,
depth: int = 3,
fan: Optional[int] = None,
limit: Optional[int] = None,
search_scope: Optional[
Literal["knowledge_graph", "document_graph"]
] = "knowledge_graph",
) -> MemoryGraph:
"""Explore the graph from given subjects up to a depth."""

View File

@@ -2,7 +2,7 @@
import json
import logging
from typing import AsyncGenerator, Iterator, List, Optional, Tuple, Union
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Tuple, Union
from dbgpt.storage.graph_store.graph import (
Direction,
@@ -173,6 +173,8 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
def create_graph_label(
self,
graph_elem_type: GraphElemType,
graph_properties: List[Dict[str, Union[str, bool]]],
) -> None:
"""Create a graph label.
@@ -201,9 +203,12 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: int | None = None,
fan: int | None = None,
limit: int | None = None,
depth: int = 3,
fan: Optional[int] = None,
limit: Optional[int] = None,
search_scope: Optional[
Literal["knowledge_graph", "document_graph"]
] = "knowledge_graph",
) -> MemoryGraph:
"""Explore the graph from given subjects up to a depth."""
return self._graph_store._graph.search(subs, direct, depth, fan, limit)

View File

@@ -79,7 +79,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
@property
def graph_store(self) -> TuGraphStore:
"""Get the graph store."""
return self._graph_store
return self._graph_store # type: ignore[return-value]
def get_graph_config(self):
"""Get the graph store config."""
@@ -176,29 +176,23 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
[{self._convert_dict_to_str(edge_list)}])"""
self.graph_store.conn.run(query=relation_query)
def upsert_chunks(
self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
) -> None:
def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None:
"""Upsert chunks."""
chunks_list = list(chunks)
if chunks_list and isinstance(chunks_list[0], ParagraphChunk):
chunk_list = [
{
"id": self._escape_quotes(chunk.chunk_id),
"name": self._escape_quotes(chunk.chunk_name),
"content": self._escape_quotes(chunk.content),
}
for chunk in chunks_list
]
else:
chunk_list = [
{
"id": self._escape_quotes(chunk.vid),
"name": self._escape_quotes(chunk.name),
"content": self._escape_quotes(chunk.get_prop("content")),
}
for chunk in chunks_list
]
chunk_list = [
{
"id": self._escape_quotes(chunk.chunk_id),
"name": self._escape_quotes(chunk.chunk_name),
"content": self._escape_quotes(chunk.content),
}
if isinstance(chunk, ParagraphChunk)
else {
"id": self._escape_quotes(chunk.vid),
"name": self._escape_quotes(chunk.name),
"content": self._escape_quotes(chunk.get_prop("content")),
}
for chunk in chunks
]
chunk_query = (
f"CALL db.upsertVertex("
f'"{GraphElemType.CHUNK.value}", '
@@ -207,28 +201,24 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
self.graph_store.conn.run(query=chunk_query)
def upsert_documents(
self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
self, documents: Iterator[Union[Vertex, ParagraphChunk]]
) -> None:
"""Upsert documents."""
documents_list = list(documents)
if documents_list and isinstance(documents_list[0], ParagraphChunk):
document_list = [
{
"id": self._escape_quotes(document.chunk_id),
"name": self._escape_quotes(document.chunk_name),
"content": "",
}
for document in documents_list
]
else:
document_list = [
{
"id": self._escape_quotes(document.vid),
"name": self._escape_quotes(document.name),
"content": self._escape_quotes(document.get_prop("content")) or "",
}
for document in documents_list
]
document_list = [
{
"id": self._escape_quotes(document.chunk_id),
"name": self._escape_quotes(document.chunk_name),
"content": "",
}
if isinstance(document, ParagraphChunk)
else {
"id": self._escape_quotes(document.vid),
"name": self._escape_quotes(document.name),
"content": "",
}
for document in documents
]
document_query = (
"CALL db.upsertVertex("
f'"{GraphElemType.DOCUMENT.value}", '
@@ -258,7 +248,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
self.graph_store.conn.run(query=vertex_query)
self.graph_store.conn.run(query=edge_query)
def upsert_graph(self, graph: MemoryGraph) -> None:
def upsert_graph(self, graph: Graph) -> None:
"""Add graph to the graph store.
Args:
@@ -362,7 +352,8 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
def create_graph(self, graph_name: str):
"""Create a graph."""
self.graph_store.conn.create_graph(graph_name=graph_name)
if not self.graph_store.conn.create_graph(graph_name=graph_name):
return
# Create the graph schema
def _format_graph_propertity_schema(
@@ -474,12 +465,14 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
(vertices) and edges in the graph.
"""
if graph_elem_type.is_vertex(): # vertex
data = json.dumps({
"label": graph_elem_type.value,
"type": "VERTEX",
"primary": "id",
"properties": graph_properties,
})
data = json.dumps(
{
"label": graph_elem_type.value,
"type": "VERTEX",
"primary": "id",
"properties": graph_properties,
}
)
gql = f"""CALL db.createVertexLabelByJson('{data}')"""
else: # edge
@@ -505,12 +498,14 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
else:
raise ValueError("Invalid graph element type.")
data = json.dumps({
"label": graph_elem_type.value,
"type": "EDGE",
"constraints": edge_direction(graph_elem_type),
"properties": graph_properties,
})
data = json.dumps(
{
"label": graph_elem_type.value,
"type": "EDGE",
"constraints": edge_direction(graph_elem_type),
"properties": graph_properties,
}
)
gql = f"""CALL db.createEdgeLabelByJson('{data}')"""
self.graph_store.conn.run(gql)
@@ -530,18 +525,16 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
True if the label exists in the specified graph element type, otherwise
False.
"""
vertex_tables, edge_tables = self.graph_store.conn.get_table_names()
tables = self.graph_store.conn.get_table_names()
if graph_elem_type.is_vertex():
return graph_elem_type in vertex_tables
else:
return graph_elem_type in edge_tables
return graph_elem_type.value in tables
def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: int = 3,
fan: Optional[int] = None,
limit: Optional[int] = None,
search_scope: Optional[
Literal["knowledge_graph", "document_graph"]
@@ -621,11 +614,17 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
mg.append_edge(edge)
return mg
async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
# type: ignore[override]
# mypy: ignore-errors
async def stream_query( # type: ignore[override]
self,
query: str,
**kwargs,
) -> AsyncGenerator[Graph, None]:
"""Execute a stream query."""
from neo4j import graph
async for record in self.graph_store.conn.run_stream(query):
async for record in self.graph_store.conn.run_stream(query): # type: ignore
mg = MemoryGraph()
for key in record.keys():
value = record[key]
@@ -650,15 +649,19 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
rels = list(record["p"].relationships)
formatted_path = []
for i in range(len(nodes)):
formatted_path.append({
"id": nodes[i]._properties["id"],
"description": nodes[i]._properties["description"],
})
formatted_path.append(
{
"id": nodes[i]._properties["id"],
"description": nodes[i]._properties["description"],
}
)
if i < len(rels):
formatted_path.append({
"id": rels[i]._properties["id"],
"description": rels[i]._properties["description"],
})
formatted_path.append(
{
"id": rels[i]._properties["id"],
"description": rels[i]._properties["description"],
}
)
for i in range(0, len(formatted_path), 2):
mg.upsert_vertex(
Vertex(