fix: fix unit test error (#2085)

Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Appointat <kuda.czk@antgroup.com>
This commit is contained in:
Florian 2024-10-22 09:35:51 +08:00 committed by GitHub
parent 6d6667812b
commit d9e20426fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 129 additions and 113 deletions

View File

@ -1,7 +1,7 @@
"""TuGraph Connector.""" """TuGraph Connector."""
import json import json
from typing import Dict, Generator, List, Tuple, cast from typing import Dict, Generator, Iterator, List, cast
from .base import BaseConnector from .base import BaseConnector
@ -20,7 +20,7 @@ class TuGraphConnector(BaseConnector):
self._graph = graph self._graph = graph
self._session = None self._session = None
def create_graph(self, graph_name: str) -> None: def create_graph(self, graph_name: str) -> bool:
"""Create a new graph in the database if it doesn't already exist.""" """Create a new graph in the database if it doesn't already exist."""
try: try:
with self._driver.session(database="default") as session: with self._driver.session(database="default") as session:
@ -33,6 +33,8 @@ class TuGraphConnector(BaseConnector):
except Exception as e: except Exception as e:
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e
return not exists
def delete_graph(self, graph_name: str) -> None: def delete_graph(self, graph_name: str) -> None:
"""Delete a graph in the database if it exists.""" """Delete a graph in the database if it exists."""
with self._driver.session(database="default") as session: with self._driver.session(database="default") as session:
@ -60,20 +62,18 @@ class TuGraphConnector(BaseConnector):
"`pip install neo4j`" "`pip install neo4j`"
) from err ) from err
def get_table_names(self) -> Tuple[List[str], List[str]]: def get_table_names(self) -> Iterator[str]:
"""Get all table names from the TuGraph by Neo4j driver.""" """Get all table names from the TuGraph by Neo4j driver."""
with self._driver.session(database=self._graph) as session: with self._driver.session(database=self._graph) as session:
# Run the query to get vertex labels # Run the query to get vertex labels
raw_vertex_labels: Dict[str, str] = session.run( raw_vertex_labels = session.run("CALL db.vertexLabels()").data()
"CALL db.vertexLabels()"
).data()
vertex_labels = [table_name["label"] for table_name in raw_vertex_labels] vertex_labels = [table_name["label"] for table_name in raw_vertex_labels]
# Run the query to get edge labels # Run the query to get edge labels
raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data() raw_edge_labels = session.run("CALL db.edgeLabels()").data()
edge_labels = [table_name["label"] for table_name in raw_edge_labels] edge_labels = [table_name["label"] for table_name in raw_edge_labels]
return vertex_labels, edge_labels return iter(vertex_labels + edge_labels)
def get_grants(self): def get_grants(self):
"""Get grants.""" """Get grants."""

View File

@ -76,8 +76,8 @@ def _parse_db_summary(
table_info_summaries = None table_info_summaries = None
if isinstance(conn, TuGraphConnector): if isinstance(conn, TuGraphConnector):
table_names = conn.get_table_names() table_names = conn.get_table_names()
v_tables = table_names.get("vertex_tables", []) v_tables = table_names.get("vertex_tables", []) # type: ignore
e_tables = table_names.get("edge_tables", []) e_tables = table_names.get("edge_tables", []) # type: ignore
table_info_summaries = [ table_info_summaries = [
_parse_table_summary(conn, summary_template, table_name, "vertex") _parse_table_summary(conn, summary_template, table_name, "vertex")
for table_name in v_tables for table_name in v_tables

View File

@ -141,8 +141,8 @@ class TuGraphStore(GraphStoreBase):
if len(missing_plugins): if len(missing_plugins):
for name in missing_plugins: for name in missing_plugins:
try: try:
from dbgpt_tugraph_plugins import ( from dbgpt_tugraph_plugins import ( # type: ignore
get_plugin_binary_path, # type:ignore[import-untyped] get_plugin_binary_path,
) )
except ImportError: except ImportError:
logger.error( logger.error(
@ -150,7 +150,7 @@ class TuGraphStore(GraphStoreBase):
"pip install dbgpt-tugraph-plugins==0.1.0rc1 -U -i " "pip install dbgpt-tugraph-plugins==0.1.0rc1 -U -i "
"https://pypi.org/simple" "https://pypi.org/simple"
) )
plugin_path = get_plugin_binary_path("leiden") plugin_path = get_plugin_binary_path("leiden") # type: ignore
with open(plugin_path, "rb") as f: with open(plugin_path, "rb") as f:
content = f.read() content = f.read()
content = base64.b64encode(content).decode() content = base64.b64encode(content).decode()

View File

@ -3,7 +3,7 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import AsyncGenerator, Iterator, List, Optional, Union from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union
from dbgpt.storage.graph_store.base import GraphStoreBase from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.graph import ( from dbgpt.storage.graph_store.graph import (
@ -156,7 +156,11 @@ class GraphStoreAdapter(ABC):
"""Create graph.""" """Create graph."""
@abstractmethod @abstractmethod
def create_graph_label(self) -> None: def create_graph_label(
self,
graph_elem_type: GraphElemType,
graph_properties: List[Dict[str, Union[str, bool]]],
) -> None:
"""Create a graph label. """Create a graph label.
The graph label is used to identify and distinguish different types of nodes The graph label is used to identify and distinguish different types of nodes
@ -176,7 +180,12 @@ class GraphStoreAdapter(ABC):
self, self,
subs: List[str], subs: List[str],
direct: Direction = Direction.BOTH, direct: Direction = Direction.BOTH,
depth: Optional[int] = None, depth: int = 3,
fan: Optional[int] = None,
limit: Optional[int] = None,
search_scope: Optional[
Literal["knowledge_graph", "document_graph"]
] = "knowledge_graph",
) -> MemoryGraph: ) -> MemoryGraph:
"""Explore the graph from given subjects up to a depth.""" """Explore the graph from given subjects up to a depth."""

View File

@ -2,7 +2,7 @@
import json import json
import logging import logging
from typing import AsyncGenerator, Iterator, List, Optional, Tuple, Union from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Tuple, Union
from dbgpt.storage.graph_store.graph import ( from dbgpt.storage.graph_store.graph import (
Direction, Direction,
@ -173,6 +173,8 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
def create_graph_label( def create_graph_label(
self, self,
graph_elem_type: GraphElemType,
graph_properties: List[Dict[str, Union[str, bool]]],
) -> None: ) -> None:
"""Create a graph label. """Create a graph label.
@ -201,9 +203,12 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
self, self,
subs: List[str], subs: List[str],
direct: Direction = Direction.BOTH, direct: Direction = Direction.BOTH,
depth: int | None = None, depth: int = 3,
fan: int | None = None, fan: Optional[int] = None,
limit: int | None = None, limit: Optional[int] = None,
search_scope: Optional[
Literal["knowledge_graph", "document_graph"]
] = "knowledge_graph",
) -> MemoryGraph: ) -> MemoryGraph:
"""Explore the graph from given subjects up to a depth.""" """Explore the graph from given subjects up to a depth."""
return self._graph_store._graph.search(subs, direct, depth, fan, limit) return self._graph_store._graph.search(subs, direct, depth, fan, limit)

View File

@ -79,7 +79,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
@property @property
def graph_store(self) -> TuGraphStore: def graph_store(self) -> TuGraphStore:
"""Get the graph store.""" """Get the graph store."""
return self._graph_store return self._graph_store # type: ignore[return-value]
def get_graph_config(self): def get_graph_config(self):
"""Get the graph store config.""" """Get the graph store config."""
@ -176,29 +176,23 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
[{self._convert_dict_to_str(edge_list)}])""" [{self._convert_dict_to_str(edge_list)}])"""
self.graph_store.conn.run(query=relation_query) self.graph_store.conn.run(query=relation_query)
def upsert_chunks( def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None:
self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
) -> None:
"""Upsert chunks.""" """Upsert chunks."""
chunks_list = list(chunks)
if chunks_list and isinstance(chunks_list[0], ParagraphChunk):
chunk_list = [ chunk_list = [
{ {
"id": self._escape_quotes(chunk.chunk_id), "id": self._escape_quotes(chunk.chunk_id),
"name": self._escape_quotes(chunk.chunk_name), "name": self._escape_quotes(chunk.chunk_name),
"content": self._escape_quotes(chunk.content), "content": self._escape_quotes(chunk.content),
} }
for chunk in chunks_list if isinstance(chunk, ParagraphChunk)
] else {
else:
chunk_list = [
{
"id": self._escape_quotes(chunk.vid), "id": self._escape_quotes(chunk.vid),
"name": self._escape_quotes(chunk.name), "name": self._escape_quotes(chunk.name),
"content": self._escape_quotes(chunk.get_prop("content")), "content": self._escape_quotes(chunk.get_prop("content")),
} }
for chunk in chunks_list for chunk in chunks
] ]
chunk_query = ( chunk_query = (
f"CALL db.upsertVertex(" f"CALL db.upsertVertex("
f'"{GraphElemType.CHUNK.value}", ' f'"{GraphElemType.CHUNK.value}", '
@ -207,28 +201,24 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
self.graph_store.conn.run(query=chunk_query) self.graph_store.conn.run(query=chunk_query)
def upsert_documents( def upsert_documents(
self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]] self, documents: Iterator[Union[Vertex, ParagraphChunk]]
) -> None: ) -> None:
"""Upsert documents.""" """Upsert documents."""
documents_list = list(documents)
if documents_list and isinstance(documents_list[0], ParagraphChunk):
document_list = [ document_list = [
{ {
"id": self._escape_quotes(document.chunk_id), "id": self._escape_quotes(document.chunk_id),
"name": self._escape_quotes(document.chunk_name), "name": self._escape_quotes(document.chunk_name),
"content": "", "content": "",
} }
for document in documents_list if isinstance(document, ParagraphChunk)
] else {
else:
document_list = [
{
"id": self._escape_quotes(document.vid), "id": self._escape_quotes(document.vid),
"name": self._escape_quotes(document.name), "name": self._escape_quotes(document.name),
"content": self._escape_quotes(document.get_prop("content")) or "", "content": "",
} }
for document in documents_list for document in documents
] ]
document_query = ( document_query = (
"CALL db.upsertVertex(" "CALL db.upsertVertex("
f'"{GraphElemType.DOCUMENT.value}", ' f'"{GraphElemType.DOCUMENT.value}", '
@ -258,7 +248,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
self.graph_store.conn.run(query=vertex_query) self.graph_store.conn.run(query=vertex_query)
self.graph_store.conn.run(query=edge_query) self.graph_store.conn.run(query=edge_query)
def upsert_graph(self, graph: MemoryGraph) -> None: def upsert_graph(self, graph: Graph) -> None:
"""Add graph to the graph store. """Add graph to the graph store.
Args: Args:
@ -362,7 +352,8 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
def create_graph(self, graph_name: str): def create_graph(self, graph_name: str):
"""Create a graph.""" """Create a graph."""
self.graph_store.conn.create_graph(graph_name=graph_name) if not self.graph_store.conn.create_graph(graph_name=graph_name):
return
# Create the graph schema # Create the graph schema
def _format_graph_propertity_schema( def _format_graph_propertity_schema(
@ -474,12 +465,14 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
(vertices) and edges in the graph. (vertices) and edges in the graph.
""" """
if graph_elem_type.is_vertex(): # vertex if graph_elem_type.is_vertex(): # vertex
data = json.dumps({ data = json.dumps(
{
"label": graph_elem_type.value, "label": graph_elem_type.value,
"type": "VERTEX", "type": "VERTEX",
"primary": "id", "primary": "id",
"properties": graph_properties, "properties": graph_properties,
}) }
)
gql = f"""CALL db.createVertexLabelByJson('{data}')""" gql = f"""CALL db.createVertexLabelByJson('{data}')"""
else: # edge else: # edge
@ -505,12 +498,14 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
else: else:
raise ValueError("Invalid graph element type.") raise ValueError("Invalid graph element type.")
data = json.dumps({ data = json.dumps(
{
"label": graph_elem_type.value, "label": graph_elem_type.value,
"type": "EDGE", "type": "EDGE",
"constraints": edge_direction(graph_elem_type), "constraints": edge_direction(graph_elem_type),
"properties": graph_properties, "properties": graph_properties,
}) }
)
gql = f"""CALL db.createEdgeLabelByJson('{data}')""" gql = f"""CALL db.createEdgeLabelByJson('{data}')"""
self.graph_store.conn.run(gql) self.graph_store.conn.run(gql)
@ -530,18 +525,16 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
True if the label exists in the specified graph element type, otherwise True if the label exists in the specified graph element type, otherwise
False. False.
""" """
vertex_tables, edge_tables = self.graph_store.conn.get_table_names() tables = self.graph_store.conn.get_table_names()
if graph_elem_type.is_vertex(): return graph_elem_type.value in tables
return graph_elem_type in vertex_tables
else:
return graph_elem_type in edge_tables
def explore( def explore(
self, self,
subs: List[str], subs: List[str],
direct: Direction = Direction.BOTH, direct: Direction = Direction.BOTH,
depth: int = 3, depth: int = 3,
fan: Optional[int] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
search_scope: Optional[ search_scope: Optional[
Literal["knowledge_graph", "document_graph"] Literal["knowledge_graph", "document_graph"]
@ -621,11 +614,17 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
mg.append_edge(edge) mg.append_edge(edge)
return mg return mg
async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]: # type: ignore[override]
# mypy: ignore-errors
async def stream_query( # type: ignore[override]
self,
query: str,
**kwargs,
) -> AsyncGenerator[Graph, None]:
"""Execute a stream query.""" """Execute a stream query."""
from neo4j import graph from neo4j import graph
async for record in self.graph_store.conn.run_stream(query): async for record in self.graph_store.conn.run_stream(query): # type: ignore
mg = MemoryGraph() mg = MemoryGraph()
for key in record.keys(): for key in record.keys():
value = record[key] value = record[key]
@ -650,15 +649,19 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
rels = list(record["p"].relationships) rels = list(record["p"].relationships)
formatted_path = [] formatted_path = []
for i in range(len(nodes)): for i in range(len(nodes)):
formatted_path.append({ formatted_path.append(
{
"id": nodes[i]._properties["id"], "id": nodes[i]._properties["id"],
"description": nodes[i]._properties["description"], "description": nodes[i]._properties["description"],
}) }
)
if i < len(rels): if i < len(rels):
formatted_path.append({ formatted_path.append(
{
"id": rels[i]._properties["id"], "id": rels[i]._properties["id"],
"description": rels[i]._properties["description"], "description": rels[i]._properties["description"],
}) }
)
for i in range(0, len(formatted_path), 2): for i in range(0, len(formatted_path), 2):
mg.upsert_vertex( mg.upsert_vertex(
Vertex( Vertex(

View File

@ -149,25 +149,25 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
return [chunk.chunk_id for chunk in chunks] return [chunk.chunk_id for chunk in chunks]
async def _aload_document_graph(self, chunks: List[Chunk]) -> List[str]: async def _aload_document_graph(self, chunks: List[Chunk]) -> None:
"""Load the knowledge graph from the chunks. """Load the knowledge graph from the chunks.
The chunks include the doc structure. The chunks include the doc structure.
""" """
if not self._graph_store.get_config().document_graph_enabled: if not self._graph_store.get_config().document_graph_enabled:
return [] return
chunks: List[ParagraphChunk] = [ _chunks: List[ParagraphChunk] = [
ParagraphChunk.model_validate(chunk.model_dump()) for chunk in chunks ParagraphChunk.model_validate(chunk.model_dump()) for chunk in chunks
] ]
documment_chunk, chunks = self._load_chunks(chunks) documment_chunk, paragraph_chunks = self._load_chunks(_chunks)
# upsert the document and chunks vertices # upsert the document and chunks vertices
self._graph_store_apdater.upsert_documents(iter([documment_chunk])) self._graph_store_apdater.upsert_documents(iter([documment_chunk]))
self._graph_store_apdater.upsert_chunks(iter(chunks)) self._graph_store_apdater.upsert_chunks(iter(paragraph_chunks))
# upsert the document structure # upsert the document structure
for chunk_index, chunk in enumerate(chunks): for chunk_index, chunk in enumerate(paragraph_chunks):
# document -> include -> chunk # document -> include -> chunk
if chunk.parent_is_document: if chunk.parent_is_document:
self._graph_store_apdater.upsert_doc_include_chunk(chunk=chunk) self._graph_store_apdater.upsert_doc_include_chunk(chunk=chunk)
@ -177,7 +177,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
# chunk -> next -> chunk # chunk -> next -> chunk
if chunk_index >= 1: if chunk_index >= 1:
self._graph_store_apdater.upsert_chunk_next_chunk( self._graph_store_apdater.upsert_chunk_next_chunk(
chunk=chunks[chunk_index - 1], next_chunk=chunk chunk=paragraph_chunks[chunk_index - 1], next_chunk=chunk
) )
async def _aload_triplet_graph(self, chunks: List[Chunk]) -> None: async def _aload_triplet_graph(self, chunks: List[Chunk]) -> None:
@ -280,7 +280,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
Return: Return:
List[Chunk]: The similar documents. List[Chunk]: The similar documents.
""" """
pass return []
async def asimilar_search_with_scores( async def asimilar_search_with_scores(
self, self,
@ -301,9 +301,6 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
keywords: List[str] = await self._keyword_extractor.extract(text) keywords: List[str] = await self._keyword_extractor.extract(text)
# Local search: extract keywords and explore subgraph # Local search: extract keywords and explore subgraph
subgraph = MemoryGraph()
subgraph_for_doc = MemoryGraph()
triplet_graph_enabled = self._graph_store.get_config().triplet_graph_enabled triplet_graph_enabled = self._graph_store.get_config().triplet_graph_enabled
document_graph_enabled = self._graph_store.get_config().document_graph_enabled document_graph_enabled = self._graph_store.get_config().document_graph_enabled
@ -329,9 +326,10 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
limit=self._config.knowledge_graph_chunk_search_top_size, limit=self._config.knowledge_graph_chunk_search_top_size,
search_scope="document_graph", search_scope="document_graph",
) )
knowledge_graph_str = subgraph.format() if subgraph else ""
knowledge_graph_str = subgraph.format() knowledge_graph_for_doc_str = (
knowledge_graph_for_doc_str = subgraph_for_doc.format() subgraph_for_doc.format() if subgraph_for_doc else ""
)
logger.info(f"Search subgraph from the following keywords:\n{len(keywords)}") logger.info(f"Search subgraph from the following keywords:\n{len(keywords)}")

View File

@ -183,5 +183,5 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
def delete_by_ids(self, ids: str) -> List[str]: def delete_by_ids(self, ids: str) -> List[str]:
"""Delete by ids.""" """Delete by ids."""
self._graph_store_apdater.delete_document(chunk_ids=ids) self._graph_store_apdater.delete_document(chunk_id=ids)
return [] return []

View File

@ -88,7 +88,8 @@ def __create_community_kg_connector():
async def ask_chunk(chunk: Chunk, question) -> str: async def ask_chunk(chunk: Chunk, question) -> str:
rag_template = ( rag_template = (
"Based on the following [Context] {context}, " "answer [Question] {question}." "Based on the following [Context] {context}, "
"answer [Question] {question}."
) )
template = HumanPromptTemplate.from_template(rag_template) template = HumanPromptTemplate.from_template(rag_template)
messages = template.format_messages(context=chunk.content, question=question) messages = template.format_messages(context=chunk.content, question=question)

View File

@ -46,9 +46,9 @@ def test_insert_and_get_triplets(tugraph_store_adapter: TuGraphStoreAdapter):
assert len(triplets) == 1 assert len(triplets) == 1
def test_query(store: TuGraphStore): def test_query(tugraph_store_adapter: TuGraphStoreAdapter):
query = "MATCH (n)-[r]->(n1) return n,n1,r limit 3" query = "MATCH (n)-[r]->(n1) return n,n1,r limit 3"
result = store.query(query) result = tugraph_store_adapter.query(query)
v_c = result.vertex_count v_c = result.vertex_count
e_c = result.edge_count e_c = result.edge_count
assert v_c == 3 and e_c == 3 assert v_c == 3 and e_c == 3

View File

@ -25,12 +25,12 @@ def g():
(lambda g: g.del_vertices("G", "G"), 6, 9), (lambda g: g.del_vertices("G", "G"), 6, 9),
(lambda g: g.del_vertices("C"), 6, 7), (lambda g: g.del_vertices("C"), 6, 7),
(lambda g: g.del_vertices("A", "G"), 5, 6), (lambda g: g.del_vertices("A", "G"), 5, 6),
(lambda g: g.del_edges("A", "A"), 7, 7), (lambda g: g.del_edges("A", "A", None), 7, 7),
(lambda g: g.del_edges("A", "B"), 7, 8), (lambda g: g.del_edges("A", "B", None), 7, 8),
(lambda g: g.del_edges("A", "A", "0"), 7, 8), (lambda g: g.del_edges("A", "A", "0"), 7, 8),
(lambda g: g.del_edges("E", "F", "8"), 7, 8), (lambda g: g.del_edges("E", "F", "8"), 7, 8),
(lambda g: g.del_edges("E", "F", "9"), 7, 9), (lambda g: g.del_edges("E", "F", "9"), 7, 9),
(lambda g: g.del_edges("E", "F", val=1), 7, 9), (lambda g: g.del_edges("E", "F", None, val=1), 7, 9),
(lambda g: g.del_edges("E", "F", "8", val=1), 7, 9), (lambda g: g.del_edges("E", "F", "8", val=1), 7, 9),
(lambda g: g.del_edges("E", "F", "9", val=1), 7, 9), (lambda g: g.del_edges("E", "F", "9", val=1), 7, 9),
(lambda g: g.del_neighbor_edges("A", Direction.IN), 7, 7), (lambda g: g.del_neighbor_edges("A", Direction.IN), 7, 7),