mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-21 19:31:43 +00:00
fix: fix unit test error (#2085)
Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: Appointat <kuda.czk@antgroup.com>
This commit is contained in:
parent
6d6667812b
commit
d9e20426fe
@ -1,7 +1,7 @@
|
|||||||
"""TuGraph Connector."""
|
"""TuGraph Connector."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Dict, Generator, List, Tuple, cast
|
from typing import Dict, Generator, Iterator, List, cast
|
||||||
|
|
||||||
from .base import BaseConnector
|
from .base import BaseConnector
|
||||||
|
|
||||||
@ -20,7 +20,7 @@ class TuGraphConnector(BaseConnector):
|
|||||||
self._graph = graph
|
self._graph = graph
|
||||||
self._session = None
|
self._session = None
|
||||||
|
|
||||||
def create_graph(self, graph_name: str) -> None:
|
def create_graph(self, graph_name: str) -> bool:
|
||||||
"""Create a new graph in the database if it doesn't already exist."""
|
"""Create a new graph in the database if it doesn't already exist."""
|
||||||
try:
|
try:
|
||||||
with self._driver.session(database="default") as session:
|
with self._driver.session(database="default") as session:
|
||||||
@ -33,6 +33,8 @@ class TuGraphConnector(BaseConnector):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e
|
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e
|
||||||
|
|
||||||
|
return not exists
|
||||||
|
|
||||||
def delete_graph(self, graph_name: str) -> None:
|
def delete_graph(self, graph_name: str) -> None:
|
||||||
"""Delete a graph in the database if it exists."""
|
"""Delete a graph in the database if it exists."""
|
||||||
with self._driver.session(database="default") as session:
|
with self._driver.session(database="default") as session:
|
||||||
@ -60,20 +62,18 @@ class TuGraphConnector(BaseConnector):
|
|||||||
"`pip install neo4j`"
|
"`pip install neo4j`"
|
||||||
) from err
|
) from err
|
||||||
|
|
||||||
def get_table_names(self) -> Tuple[List[str], List[str]]:
|
def get_table_names(self) -> Iterator[str]:
|
||||||
"""Get all table names from the TuGraph by Neo4j driver."""
|
"""Get all table names from the TuGraph by Neo4j driver."""
|
||||||
with self._driver.session(database=self._graph) as session:
|
with self._driver.session(database=self._graph) as session:
|
||||||
# Run the query to get vertex labels
|
# Run the query to get vertex labels
|
||||||
raw_vertex_labels: Dict[str, str] = session.run(
|
raw_vertex_labels = session.run("CALL db.vertexLabels()").data()
|
||||||
"CALL db.vertexLabels()"
|
|
||||||
).data()
|
|
||||||
vertex_labels = [table_name["label"] for table_name in raw_vertex_labels]
|
vertex_labels = [table_name["label"] for table_name in raw_vertex_labels]
|
||||||
|
|
||||||
# Run the query to get edge labels
|
# Run the query to get edge labels
|
||||||
raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data()
|
raw_edge_labels = session.run("CALL db.edgeLabels()").data()
|
||||||
edge_labels = [table_name["label"] for table_name in raw_edge_labels]
|
edge_labels = [table_name["label"] for table_name in raw_edge_labels]
|
||||||
|
|
||||||
return vertex_labels, edge_labels
|
return iter(vertex_labels + edge_labels)
|
||||||
|
|
||||||
def get_grants(self):
|
def get_grants(self):
|
||||||
"""Get grants."""
|
"""Get grants."""
|
||||||
|
@ -76,8 +76,8 @@ def _parse_db_summary(
|
|||||||
table_info_summaries = None
|
table_info_summaries = None
|
||||||
if isinstance(conn, TuGraphConnector):
|
if isinstance(conn, TuGraphConnector):
|
||||||
table_names = conn.get_table_names()
|
table_names = conn.get_table_names()
|
||||||
v_tables = table_names.get("vertex_tables", [])
|
v_tables = table_names.get("vertex_tables", []) # type: ignore
|
||||||
e_tables = table_names.get("edge_tables", [])
|
e_tables = table_names.get("edge_tables", []) # type: ignore
|
||||||
table_info_summaries = [
|
table_info_summaries = [
|
||||||
_parse_table_summary(conn, summary_template, table_name, "vertex")
|
_parse_table_summary(conn, summary_template, table_name, "vertex")
|
||||||
for table_name in v_tables
|
for table_name in v_tables
|
||||||
|
@ -141,8 +141,8 @@ class TuGraphStore(GraphStoreBase):
|
|||||||
if len(missing_plugins):
|
if len(missing_plugins):
|
||||||
for name in missing_plugins:
|
for name in missing_plugins:
|
||||||
try:
|
try:
|
||||||
from dbgpt_tugraph_plugins import (
|
from dbgpt_tugraph_plugins import ( # type: ignore
|
||||||
get_plugin_binary_path, # type:ignore[import-untyped]
|
get_plugin_binary_path,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error(
|
logger.error(
|
||||||
@ -150,7 +150,7 @@ class TuGraphStore(GraphStoreBase):
|
|||||||
"pip install dbgpt-tugraph-plugins==0.1.0rc1 -U -i "
|
"pip install dbgpt-tugraph-plugins==0.1.0rc1 -U -i "
|
||||||
"https://pypi.org/simple"
|
"https://pypi.org/simple"
|
||||||
)
|
)
|
||||||
plugin_path = get_plugin_binary_path("leiden")
|
plugin_path = get_plugin_binary_path("leiden") # type: ignore
|
||||||
with open(plugin_path, "rb") as f:
|
with open(plugin_path, "rb") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
content = base64.b64encode(content).decode()
|
content = base64.b64encode(content).decode()
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import AsyncGenerator, Iterator, List, Optional, Union
|
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union
|
||||||
|
|
||||||
from dbgpt.storage.graph_store.base import GraphStoreBase
|
from dbgpt.storage.graph_store.base import GraphStoreBase
|
||||||
from dbgpt.storage.graph_store.graph import (
|
from dbgpt.storage.graph_store.graph import (
|
||||||
@ -156,7 +156,11 @@ class GraphStoreAdapter(ABC):
|
|||||||
"""Create graph."""
|
"""Create graph."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_graph_label(self) -> None:
|
def create_graph_label(
|
||||||
|
self,
|
||||||
|
graph_elem_type: GraphElemType,
|
||||||
|
graph_properties: List[Dict[str, Union[str, bool]]],
|
||||||
|
) -> None:
|
||||||
"""Create a graph label.
|
"""Create a graph label.
|
||||||
|
|
||||||
The graph label is used to identify and distinguish different types of nodes
|
The graph label is used to identify and distinguish different types of nodes
|
||||||
@ -176,7 +180,12 @@ class GraphStoreAdapter(ABC):
|
|||||||
self,
|
self,
|
||||||
subs: List[str],
|
subs: List[str],
|
||||||
direct: Direction = Direction.BOTH,
|
direct: Direction = Direction.BOTH,
|
||||||
depth: Optional[int] = None,
|
depth: int = 3,
|
||||||
|
fan: Optional[int] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
search_scope: Optional[
|
||||||
|
Literal["knowledge_graph", "document_graph"]
|
||||||
|
] = "knowledge_graph",
|
||||||
) -> MemoryGraph:
|
) -> MemoryGraph:
|
||||||
"""Explore the graph from given subjects up to a depth."""
|
"""Explore the graph from given subjects up to a depth."""
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import AsyncGenerator, Iterator, List, Optional, Tuple, Union
|
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
from dbgpt.storage.graph_store.graph import (
|
from dbgpt.storage.graph_store.graph import (
|
||||||
Direction,
|
Direction,
|
||||||
@ -173,6 +173,8 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
|
|||||||
|
|
||||||
def create_graph_label(
|
def create_graph_label(
|
||||||
self,
|
self,
|
||||||
|
graph_elem_type: GraphElemType,
|
||||||
|
graph_properties: List[Dict[str, Union[str, bool]]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a graph label.
|
"""Create a graph label.
|
||||||
|
|
||||||
@ -201,9 +203,12 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
|
|||||||
self,
|
self,
|
||||||
subs: List[str],
|
subs: List[str],
|
||||||
direct: Direction = Direction.BOTH,
|
direct: Direction = Direction.BOTH,
|
||||||
depth: int | None = None,
|
depth: int = 3,
|
||||||
fan: int | None = None,
|
fan: Optional[int] = None,
|
||||||
limit: int | None = None,
|
limit: Optional[int] = None,
|
||||||
|
search_scope: Optional[
|
||||||
|
Literal["knowledge_graph", "document_graph"]
|
||||||
|
] = "knowledge_graph",
|
||||||
) -> MemoryGraph:
|
) -> MemoryGraph:
|
||||||
"""Explore the graph from given subjects up to a depth."""
|
"""Explore the graph from given subjects up to a depth."""
|
||||||
return self._graph_store._graph.search(subs, direct, depth, fan, limit)
|
return self._graph_store._graph.search(subs, direct, depth, fan, limit)
|
||||||
|
@ -79,7 +79,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
|||||||
@property
|
@property
|
||||||
def graph_store(self) -> TuGraphStore:
|
def graph_store(self) -> TuGraphStore:
|
||||||
"""Get the graph store."""
|
"""Get the graph store."""
|
||||||
return self._graph_store
|
return self._graph_store # type: ignore[return-value]
|
||||||
|
|
||||||
def get_graph_config(self):
|
def get_graph_config(self):
|
||||||
"""Get the graph store config."""
|
"""Get the graph store config."""
|
||||||
@ -176,29 +176,23 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
|||||||
[{self._convert_dict_to_str(edge_list)}])"""
|
[{self._convert_dict_to_str(edge_list)}])"""
|
||||||
self.graph_store.conn.run(query=relation_query)
|
self.graph_store.conn.run(query=relation_query)
|
||||||
|
|
||||||
def upsert_chunks(
|
def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None:
|
||||||
self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
|
|
||||||
) -> None:
|
|
||||||
"""Upsert chunks."""
|
"""Upsert chunks."""
|
||||||
chunks_list = list(chunks)
|
|
||||||
if chunks_list and isinstance(chunks_list[0], ParagraphChunk):
|
|
||||||
chunk_list = [
|
chunk_list = [
|
||||||
{
|
{
|
||||||
"id": self._escape_quotes(chunk.chunk_id),
|
"id": self._escape_quotes(chunk.chunk_id),
|
||||||
"name": self._escape_quotes(chunk.chunk_name),
|
"name": self._escape_quotes(chunk.chunk_name),
|
||||||
"content": self._escape_quotes(chunk.content),
|
"content": self._escape_quotes(chunk.content),
|
||||||
}
|
}
|
||||||
for chunk in chunks_list
|
if isinstance(chunk, ParagraphChunk)
|
||||||
]
|
else {
|
||||||
else:
|
|
||||||
chunk_list = [
|
|
||||||
{
|
|
||||||
"id": self._escape_quotes(chunk.vid),
|
"id": self._escape_quotes(chunk.vid),
|
||||||
"name": self._escape_quotes(chunk.name),
|
"name": self._escape_quotes(chunk.name),
|
||||||
"content": self._escape_quotes(chunk.get_prop("content")),
|
"content": self._escape_quotes(chunk.get_prop("content")),
|
||||||
}
|
}
|
||||||
for chunk in chunks_list
|
for chunk in chunks
|
||||||
]
|
]
|
||||||
|
|
||||||
chunk_query = (
|
chunk_query = (
|
||||||
f"CALL db.upsertVertex("
|
f"CALL db.upsertVertex("
|
||||||
f'"{GraphElemType.CHUNK.value}", '
|
f'"{GraphElemType.CHUNK.value}", '
|
||||||
@ -207,28 +201,24 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
|||||||
self.graph_store.conn.run(query=chunk_query)
|
self.graph_store.conn.run(query=chunk_query)
|
||||||
|
|
||||||
def upsert_documents(
|
def upsert_documents(
|
||||||
self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
|
self, documents: Iterator[Union[Vertex, ParagraphChunk]]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Upsert documents."""
|
"""Upsert documents."""
|
||||||
documents_list = list(documents)
|
|
||||||
if documents_list and isinstance(documents_list[0], ParagraphChunk):
|
|
||||||
document_list = [
|
document_list = [
|
||||||
{
|
{
|
||||||
"id": self._escape_quotes(document.chunk_id),
|
"id": self._escape_quotes(document.chunk_id),
|
||||||
"name": self._escape_quotes(document.chunk_name),
|
"name": self._escape_quotes(document.chunk_name),
|
||||||
"content": "",
|
"content": "",
|
||||||
}
|
}
|
||||||
for document in documents_list
|
if isinstance(document, ParagraphChunk)
|
||||||
]
|
else {
|
||||||
else:
|
|
||||||
document_list = [
|
|
||||||
{
|
|
||||||
"id": self._escape_quotes(document.vid),
|
"id": self._escape_quotes(document.vid),
|
||||||
"name": self._escape_quotes(document.name),
|
"name": self._escape_quotes(document.name),
|
||||||
"content": self._escape_quotes(document.get_prop("content")) or "",
|
"content": "",
|
||||||
}
|
}
|
||||||
for document in documents_list
|
for document in documents
|
||||||
]
|
]
|
||||||
|
|
||||||
document_query = (
|
document_query = (
|
||||||
"CALL db.upsertVertex("
|
"CALL db.upsertVertex("
|
||||||
f'"{GraphElemType.DOCUMENT.value}", '
|
f'"{GraphElemType.DOCUMENT.value}", '
|
||||||
@ -258,7 +248,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
|||||||
self.graph_store.conn.run(query=vertex_query)
|
self.graph_store.conn.run(query=vertex_query)
|
||||||
self.graph_store.conn.run(query=edge_query)
|
self.graph_store.conn.run(query=edge_query)
|
||||||
|
|
||||||
def upsert_graph(self, graph: MemoryGraph) -> None:
|
def upsert_graph(self, graph: Graph) -> None:
|
||||||
"""Add graph to the graph store.
|
"""Add graph to the graph store.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -362,7 +352,8 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
|||||||
|
|
||||||
def create_graph(self, graph_name: str):
|
def create_graph(self, graph_name: str):
|
||||||
"""Create a graph."""
|
"""Create a graph."""
|
||||||
self.graph_store.conn.create_graph(graph_name=graph_name)
|
if not self.graph_store.conn.create_graph(graph_name=graph_name):
|
||||||
|
return
|
||||||
|
|
||||||
# Create the graph schema
|
# Create the graph schema
|
||||||
def _format_graph_propertity_schema(
|
def _format_graph_propertity_schema(
|
||||||
@ -474,12 +465,14 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
|||||||
(vertices) and edges in the graph.
|
(vertices) and edges in the graph.
|
||||||
"""
|
"""
|
||||||
if graph_elem_type.is_vertex(): # vertex
|
if graph_elem_type.is_vertex(): # vertex
|
||||||
data = json.dumps({
|
data = json.dumps(
|
||||||
|
{
|
||||||
"label": graph_elem_type.value,
|
"label": graph_elem_type.value,
|
||||||
"type": "VERTEX",
|
"type": "VERTEX",
|
||||||
"primary": "id",
|
"primary": "id",
|
||||||
"properties": graph_properties,
|
"properties": graph_properties,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
gql = f"""CALL db.createVertexLabelByJson('{data}')"""
|
gql = f"""CALL db.createVertexLabelByJson('{data}')"""
|
||||||
else: # edge
|
else: # edge
|
||||||
|
|
||||||
@ -505,12 +498,14 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid graph element type.")
|
raise ValueError("Invalid graph element type.")
|
||||||
|
|
||||||
data = json.dumps({
|
data = json.dumps(
|
||||||
|
{
|
||||||
"label": graph_elem_type.value,
|
"label": graph_elem_type.value,
|
||||||
"type": "EDGE",
|
"type": "EDGE",
|
||||||
"constraints": edge_direction(graph_elem_type),
|
"constraints": edge_direction(graph_elem_type),
|
||||||
"properties": graph_properties,
|
"properties": graph_properties,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
gql = f"""CALL db.createEdgeLabelByJson('{data}')"""
|
gql = f"""CALL db.createEdgeLabelByJson('{data}')"""
|
||||||
|
|
||||||
self.graph_store.conn.run(gql)
|
self.graph_store.conn.run(gql)
|
||||||
@ -530,18 +525,16 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
|||||||
True if the label exists in the specified graph element type, otherwise
|
True if the label exists in the specified graph element type, otherwise
|
||||||
False.
|
False.
|
||||||
"""
|
"""
|
||||||
vertex_tables, edge_tables = self.graph_store.conn.get_table_names()
|
tables = self.graph_store.conn.get_table_names()
|
||||||
|
|
||||||
if graph_elem_type.is_vertex():
|
return graph_elem_type.value in tables
|
||||||
return graph_elem_type in vertex_tables
|
|
||||||
else:
|
|
||||||
return graph_elem_type in edge_tables
|
|
||||||
|
|
||||||
def explore(
|
def explore(
|
||||||
self,
|
self,
|
||||||
subs: List[str],
|
subs: List[str],
|
||||||
direct: Direction = Direction.BOTH,
|
direct: Direction = Direction.BOTH,
|
||||||
depth: int = 3,
|
depth: int = 3,
|
||||||
|
fan: Optional[int] = None,
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
search_scope: Optional[
|
search_scope: Optional[
|
||||||
Literal["knowledge_graph", "document_graph"]
|
Literal["knowledge_graph", "document_graph"]
|
||||||
@ -621,11 +614,17 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
|||||||
mg.append_edge(edge)
|
mg.append_edge(edge)
|
||||||
return mg
|
return mg
|
||||||
|
|
||||||
async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
|
# type: ignore[override]
|
||||||
|
# mypy: ignore-errors
|
||||||
|
async def stream_query( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> AsyncGenerator[Graph, None]:
|
||||||
"""Execute a stream query."""
|
"""Execute a stream query."""
|
||||||
from neo4j import graph
|
from neo4j import graph
|
||||||
|
|
||||||
async for record in self.graph_store.conn.run_stream(query):
|
async for record in self.graph_store.conn.run_stream(query): # type: ignore
|
||||||
mg = MemoryGraph()
|
mg = MemoryGraph()
|
||||||
for key in record.keys():
|
for key in record.keys():
|
||||||
value = record[key]
|
value = record[key]
|
||||||
@ -650,15 +649,19 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
|||||||
rels = list(record["p"].relationships)
|
rels = list(record["p"].relationships)
|
||||||
formatted_path = []
|
formatted_path = []
|
||||||
for i in range(len(nodes)):
|
for i in range(len(nodes)):
|
||||||
formatted_path.append({
|
formatted_path.append(
|
||||||
|
{
|
||||||
"id": nodes[i]._properties["id"],
|
"id": nodes[i]._properties["id"],
|
||||||
"description": nodes[i]._properties["description"],
|
"description": nodes[i]._properties["description"],
|
||||||
})
|
}
|
||||||
|
)
|
||||||
if i < len(rels):
|
if i < len(rels):
|
||||||
formatted_path.append({
|
formatted_path.append(
|
||||||
|
{
|
||||||
"id": rels[i]._properties["id"],
|
"id": rels[i]._properties["id"],
|
||||||
"description": rels[i]._properties["description"],
|
"description": rels[i]._properties["description"],
|
||||||
})
|
}
|
||||||
|
)
|
||||||
for i in range(0, len(formatted_path), 2):
|
for i in range(0, len(formatted_path), 2):
|
||||||
mg.upsert_vertex(
|
mg.upsert_vertex(
|
||||||
Vertex(
|
Vertex(
|
||||||
|
@ -149,25 +149,25 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
|||||||
|
|
||||||
return [chunk.chunk_id for chunk in chunks]
|
return [chunk.chunk_id for chunk in chunks]
|
||||||
|
|
||||||
async def _aload_document_graph(self, chunks: List[Chunk]) -> List[str]:
|
async def _aload_document_graph(self, chunks: List[Chunk]) -> None:
|
||||||
"""Load the knowledge graph from the chunks.
|
"""Load the knowledge graph from the chunks.
|
||||||
|
|
||||||
The chunks include the doc structure.
|
The chunks include the doc structure.
|
||||||
"""
|
"""
|
||||||
if not self._graph_store.get_config().document_graph_enabled:
|
if not self._graph_store.get_config().document_graph_enabled:
|
||||||
return []
|
return
|
||||||
|
|
||||||
chunks: List[ParagraphChunk] = [
|
_chunks: List[ParagraphChunk] = [
|
||||||
ParagraphChunk.model_validate(chunk.model_dump()) for chunk in chunks
|
ParagraphChunk.model_validate(chunk.model_dump()) for chunk in chunks
|
||||||
]
|
]
|
||||||
documment_chunk, chunks = self._load_chunks(chunks)
|
documment_chunk, paragraph_chunks = self._load_chunks(_chunks)
|
||||||
|
|
||||||
# upsert the document and chunks vertices
|
# upsert the document and chunks vertices
|
||||||
self._graph_store_apdater.upsert_documents(iter([documment_chunk]))
|
self._graph_store_apdater.upsert_documents(iter([documment_chunk]))
|
||||||
self._graph_store_apdater.upsert_chunks(iter(chunks))
|
self._graph_store_apdater.upsert_chunks(iter(paragraph_chunks))
|
||||||
|
|
||||||
# upsert the document structure
|
# upsert the document structure
|
||||||
for chunk_index, chunk in enumerate(chunks):
|
for chunk_index, chunk in enumerate(paragraph_chunks):
|
||||||
# document -> include -> chunk
|
# document -> include -> chunk
|
||||||
if chunk.parent_is_document:
|
if chunk.parent_is_document:
|
||||||
self._graph_store_apdater.upsert_doc_include_chunk(chunk=chunk)
|
self._graph_store_apdater.upsert_doc_include_chunk(chunk=chunk)
|
||||||
@ -177,7 +177,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
|||||||
# chunk -> next -> chunk
|
# chunk -> next -> chunk
|
||||||
if chunk_index >= 1:
|
if chunk_index >= 1:
|
||||||
self._graph_store_apdater.upsert_chunk_next_chunk(
|
self._graph_store_apdater.upsert_chunk_next_chunk(
|
||||||
chunk=chunks[chunk_index - 1], next_chunk=chunk
|
chunk=paragraph_chunks[chunk_index - 1], next_chunk=chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _aload_triplet_graph(self, chunks: List[Chunk]) -> None:
|
async def _aload_triplet_graph(self, chunks: List[Chunk]) -> None:
|
||||||
@ -280,7 +280,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
|||||||
Return:
|
Return:
|
||||||
List[Chunk]: The similar documents.
|
List[Chunk]: The similar documents.
|
||||||
"""
|
"""
|
||||||
pass
|
return []
|
||||||
|
|
||||||
async def asimilar_search_with_scores(
|
async def asimilar_search_with_scores(
|
||||||
self,
|
self,
|
||||||
@ -301,9 +301,6 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
|||||||
keywords: List[str] = await self._keyword_extractor.extract(text)
|
keywords: List[str] = await self._keyword_extractor.extract(text)
|
||||||
|
|
||||||
# Local search: extract keywords and explore subgraph
|
# Local search: extract keywords and explore subgraph
|
||||||
subgraph = MemoryGraph()
|
|
||||||
subgraph_for_doc = MemoryGraph()
|
|
||||||
|
|
||||||
triplet_graph_enabled = self._graph_store.get_config().triplet_graph_enabled
|
triplet_graph_enabled = self._graph_store.get_config().triplet_graph_enabled
|
||||||
document_graph_enabled = self._graph_store.get_config().document_graph_enabled
|
document_graph_enabled = self._graph_store.get_config().document_graph_enabled
|
||||||
|
|
||||||
@ -329,9 +326,10 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
|||||||
limit=self._config.knowledge_graph_chunk_search_top_size,
|
limit=self._config.knowledge_graph_chunk_search_top_size,
|
||||||
search_scope="document_graph",
|
search_scope="document_graph",
|
||||||
)
|
)
|
||||||
|
knowledge_graph_str = subgraph.format() if subgraph else ""
|
||||||
knowledge_graph_str = subgraph.format()
|
knowledge_graph_for_doc_str = (
|
||||||
knowledge_graph_for_doc_str = subgraph_for_doc.format()
|
subgraph_for_doc.format() if subgraph_for_doc else ""
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Search subgraph from the following keywords:\n{len(keywords)}")
|
logger.info(f"Search subgraph from the following keywords:\n{len(keywords)}")
|
||||||
|
|
||||||
|
@ -183,5 +183,5 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
|||||||
|
|
||||||
def delete_by_ids(self, ids: str) -> List[str]:
|
def delete_by_ids(self, ids: str) -> List[str]:
|
||||||
"""Delete by ids."""
|
"""Delete by ids."""
|
||||||
self._graph_store_apdater.delete_document(chunk_ids=ids)
|
self._graph_store_apdater.delete_document(chunk_id=ids)
|
||||||
return []
|
return []
|
||||||
|
@ -88,7 +88,8 @@ def __create_community_kg_connector():
|
|||||||
|
|
||||||
async def ask_chunk(chunk: Chunk, question) -> str:
|
async def ask_chunk(chunk: Chunk, question) -> str:
|
||||||
rag_template = (
|
rag_template = (
|
||||||
"Based on the following [Context] {context}, " "answer [Question] {question}."
|
"Based on the following [Context] {context}, "
|
||||||
|
"answer [Question] {question}."
|
||||||
)
|
)
|
||||||
template = HumanPromptTemplate.from_template(rag_template)
|
template = HumanPromptTemplate.from_template(rag_template)
|
||||||
messages = template.format_messages(context=chunk.content, question=question)
|
messages = template.format_messages(context=chunk.content, question=question)
|
||||||
|
@ -46,9 +46,9 @@ def test_insert_and_get_triplets(tugraph_store_adapter: TuGraphStoreAdapter):
|
|||||||
assert len(triplets) == 1
|
assert len(triplets) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_query(store: TuGraphStore):
|
def test_query(tugraph_store_adapter: TuGraphStoreAdapter):
|
||||||
query = "MATCH (n)-[r]->(n1) return n,n1,r limit 3"
|
query = "MATCH (n)-[r]->(n1) return n,n1,r limit 3"
|
||||||
result = store.query(query)
|
result = tugraph_store_adapter.query(query)
|
||||||
v_c = result.vertex_count
|
v_c = result.vertex_count
|
||||||
e_c = result.edge_count
|
e_c = result.edge_count
|
||||||
assert v_c == 3 and e_c == 3
|
assert v_c == 3 and e_c == 3
|
||||||
|
@ -25,12 +25,12 @@ def g():
|
|||||||
(lambda g: g.del_vertices("G", "G"), 6, 9),
|
(lambda g: g.del_vertices("G", "G"), 6, 9),
|
||||||
(lambda g: g.del_vertices("C"), 6, 7),
|
(lambda g: g.del_vertices("C"), 6, 7),
|
||||||
(lambda g: g.del_vertices("A", "G"), 5, 6),
|
(lambda g: g.del_vertices("A", "G"), 5, 6),
|
||||||
(lambda g: g.del_edges("A", "A"), 7, 7),
|
(lambda g: g.del_edges("A", "A", None), 7, 7),
|
||||||
(lambda g: g.del_edges("A", "B"), 7, 8),
|
(lambda g: g.del_edges("A", "B", None), 7, 8),
|
||||||
(lambda g: g.del_edges("A", "A", "0"), 7, 8),
|
(lambda g: g.del_edges("A", "A", "0"), 7, 8),
|
||||||
(lambda g: g.del_edges("E", "F", "8"), 7, 8),
|
(lambda g: g.del_edges("E", "F", "8"), 7, 8),
|
||||||
(lambda g: g.del_edges("E", "F", "9"), 7, 9),
|
(lambda g: g.del_edges("E", "F", "9"), 7, 9),
|
||||||
(lambda g: g.del_edges("E", "F", val=1), 7, 9),
|
(lambda g: g.del_edges("E", "F", None, val=1), 7, 9),
|
||||||
(lambda g: g.del_edges("E", "F", "8", val=1), 7, 9),
|
(lambda g: g.del_edges("E", "F", "8", val=1), 7, 9),
|
||||||
(lambda g: g.del_edges("E", "F", "9", val=1), 7, 9),
|
(lambda g: g.del_edges("E", "F", "9", val=1), 7, 9),
|
||||||
(lambda g: g.del_neighbor_edges("A", Direction.IN), 7, 7),
|
(lambda g: g.del_neighbor_edges("A", Direction.IN), 7, 7),
|
||||||
|
Loading…
Reference in New Issue
Block a user