Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-11 05:49:22 +00:00)
feat: add document structure into GraphRAG (#2033)
Co-authored-by: Appointat <kuda.czk@antgroup.com>
Co-authored-by: tpoisonooo <khj.application@aliyun.com>
Co-authored-by: vritser <vritser@163.com>
@@ -1,11 +1,11 @@
"""Graph store base class."""

import logging
from abc import ABC, abstractmethod
-from typing import Generator, List, Optional, Tuple
+from typing import Optional

from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core import Embeddings
from dbgpt.storage.graph_store.graph import Direction, Graph

logger = logging.getLogger(__name__)

@@ -23,78 +23,36 @@ class GraphStoreConfig(BaseModel):
        default=None,
        description="The embedding function of graph store, optional.",
    )
-    summary_enabled: bool = Field(
+    enable_summary: bool = Field(
        default=False,
        description="Enable graph community summary or not.",
    )
+    enable_document_graph: bool = Field(
+        default=True,
+        description="Enable document graph search or not.",
+    )
+    enable_triplet_graph: bool = Field(
+        default=True,
+        description="Enable knowledge graph search or not.",
+    )


class GraphStoreBase(ABC):
    """Graph store base class."""

    def __init__(self, config: GraphStoreConfig):
        """Initialize graph store."""
        self._config = config
        self._conn = None

    @abstractmethod
    def get_config(self) -> GraphStoreConfig:
        """Get the graph store config."""

    @abstractmethod
-    def get_vertex_type(self) -> str:
-        """Get the vertex type."""
+    def _escape_quotes(self, text: str) -> str:
+        """Escape single and double quotes in a string for queries."""

-    @abstractmethod
-    def get_edge_type(self) -> str:
-        """Get the edge type."""
-
-    @abstractmethod
-    def insert_triplet(self, sub: str, rel: str, obj: str):
-        """Add triplet."""
-
-    @abstractmethod
-    def insert_graph(self, graph: Graph):
-        """Add graph."""
-
-    @abstractmethod
-    def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
-        """Get triplets."""
-
-    @abstractmethod
-    def delete_triplet(self, sub: str, rel: str, obj: str):
-        """Delete triplet."""
-
-    @abstractmethod
-    def truncate(self):
-        """Truncate Graph."""
-
-    @abstractmethod
-    def drop(self):
-        """Drop graph."""
-
-    @abstractmethod
-    def get_schema(self, refresh: bool = False) -> str:
-        """Get schema."""
-
-    @abstractmethod
-    def get_full_graph(self, limit: Optional[int] = None) -> Graph:
-        """Get full graph."""
-
-    @abstractmethod
-    def explore(
-        self,
-        subs: List[str],
-        direct: Direction = Direction.BOTH,
-        depth: Optional[int] = None,
-        fan: Optional[int] = None,
-        limit: Optional[int] = None,
-    ) -> Graph:
-        """Explore on graph."""
-
-    @abstractmethod
-    def query(self, query: str, **args) -> Graph:
-        """Execute a query."""
-
-    def aquery(self, query: str, **args) -> Graph:
-        """Async execute a query."""
-        return self.query(query, **args)
-
-    @abstractmethod
-    def stream_query(self, query: str) -> Generator[Graph, None, None]:
-        """Execute stream query."""
-    # @abstractmethod
-    # def _paser(self, entities: List[Vertex]) -> str:
-    #     """Parse entities to string."""
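As an illustration (not part of this commit), a minimal sketch of how the renamed and added flags might be set when building a store config; the `name` field and the TuGraphStore constructor appear elsewhere in this diff, everything else here is an assumption:

# Hypothetical usage sketch: turn on community summaries plus both
# retrieval channels introduced by this commit.
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore, TuGraphStoreConfig

config = TuGraphStoreConfig(
    name="dbgpt_kg",             # graph (database) name
    enable_summary=True,         # replaces the old `summary_enabled`
    enable_document_graph=True,  # new: document/chunk structure search
    enable_triplet_graph=True,   # new: entity/relation (triplet) search
)
store = TuGraphStore(config)     # ENABLE_* environment variables can still override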
@@ -1,4 +1,5 @@
"""Graph store factory."""
+
import logging
from typing import Tuple, Type

@@ -1,4 +1,5 @@
"""Graph definition."""
+
import itertools
import json
import logging
@@ -6,13 +7,41 @@ import re
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
-from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
+from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple

import networkx as nx

logger = logging.getLogger(__name__)


+class GraphElemType(Enum):
+    """Type of element in graph."""
+
+    DOCUMENT = "document"
+    CHUNK = "chunk"
+    ENTITY = "entity"  # view as general vertex in the general case
+    RELATION = "relation"  # view as general edge in the general case
+    INCLUDE = "include"
+    NEXT = "next"
+
+    DOCUMENT_INCLUDE_CHUNK = "document_include_chunk"
+    CHUNK_INCLUDE_CHUNK = "chunk_include_chunk"
+    CHUNK_INCLUDE_ENTITY = "chunk_include_entity"
+    CHUNK_NEXT_CHUNK = "chunk_next_chunk"
+
+    def is_vertex(self) -> bool:
+        """Check if the element is a vertex."""
+        return self in [
+            GraphElemType.DOCUMENT,
+            GraphElemType.CHUNK,
+            GraphElemType.ENTITY,
+        ]
+
+    def is_edge(self) -> bool:
+        """Check if the element is an edge."""
+        return not self.is_vertex()
+
+
class Direction(Enum):
    """Direction class."""

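A small illustration of how the new GraphElemType enum separates vertex labels from edge labels; the assertions follow directly from is_vertex/is_edge above:

from dbgpt.storage.graph_store.graph import GraphElemType

assert GraphElemType.DOCUMENT.is_vertex()        # document/chunk/entity are vertices
assert GraphElemType.ENTITY.is_vertex()
assert GraphElemType.INCLUDE.is_edge()           # every other member is an edge label
assert GraphElemType.CHUNK_NEXT_CHUNK.is_edge()  # typed edge between ordered chunks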
@@ -41,7 +70,7 @@ class Elem(ABC):

    def set_prop(self, key: str, value: Any):
        """Set a property of Elem."""
-        self._props[key] = value
+        self._props[key] = value  # note: always update the value

    def get_prop(self, key: str):
        """Get one of the properties of Elem."""
@@ -124,6 +153,18 @@ class Edge(Elem):
        for k, v in props.items():
            self.set_prop(k, v)

+    def __eq__(self, other):
+        """Check if two edges are equal.
+
+        Two edges are considered equal if they have the same source vertex ID,
+        target vertex ID, and edge label. The properties are not considered.
+        """
+        return (self.sid, self.tid, self.name) == (other.sid, other.tid, other.name)
+
+    def __hash__(self):
+        """Return the hash value of the edge."""
+        return hash((self.sid, self.tid, self.name))
+
    @property
    def sid(self) -> str:
        """Return the source vertex ID of the edge."""
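Because __eq__ and __hash__ consider only (sid, tid, name), two edges that differ just in their properties compare equal and collapse in sets, which is what MemoryGraph.append_edge relies on for deduplication; a short sketch (not part of this commit):

from dbgpt.storage.graph_store.graph import Edge

e1 = Edge("doc1", "chunk1", "include", edge_type="document_include_chunk")
e2 = Edge("doc1", "chunk1", "include")  # same endpoints and label, no props

assert e1 == e2            # properties are ignored by __eq__
assert len({e1, e2}) == 1  # and by __hash__, so a set keeps only one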
@@ -188,11 +229,15 @@ class Graph(ABC):
        """Get neighbor edges."""

    @abstractmethod
-    def vertices(self) -> Iterator[Vertex]:
+    def vertices(
+        self, filter_fn: Optional[Callable[[Vertex], bool]] = None
+    ) -> Iterator[Vertex]:
        """Get vertex iterator."""

    @abstractmethod
-    def edges(self) -> Iterator[Edge]:
+    def edges(
+        self, filter_fn: Optional[Callable[[Edge], bool]] = None
+    ) -> Iterator[Edge]:
        """Get edge iterator."""

    @abstractmethod
@@ -241,7 +286,7 @@ class MemoryGraph(Graph):
        self._edge_prop_keys = set()
        self._edge_count = 0

-        # init vertices, out edges, in edges index
+        # vertices index, out edges index, in edges index
        self._vs: Any = defaultdict()
        self._oes: Any = defaultdict(lambda: defaultdict(set))
        self._ies: Any = defaultdict(lambda: defaultdict(set))
@@ -269,7 +314,7 @@ class MemoryGraph(Graph):
        # update metadata
        self._vertex_prop_keys.update(vertex.props.keys())

-    def append_edge(self, edge: Edge):
+    def append_edge(self, edge: Edge) -> bool:
        """Append an edge if it doesn't exist; requires edge label."""
        sid = edge.sid
        tid = edge.tid
@@ -290,6 +335,34 @@ class MemoryGraph(Graph):
        self._edge_count += 1
        return True

+    def upsert_vertex_and_edge(
+        self,
+        src_vid: str,
+        src_name: str,
+        src_props: Dict[str, Any],
+        dst_vid: str,
+        dst_name: str,
+        dst_props: Dict[str, Any],
+        edge_name: str,
+        edge_type: str,
+    ):
+        """Upsert the src vertex, the dst vertex, and the edge between them."""
+        src_vertex = Vertex(src_vid, src_name, **src_props)
+        dst_vertex = Vertex(dst_vid, dst_name, **dst_props)
+        edge = Edge(src_vid, dst_vid, edge_name, **{"edge_type": edge_type})
+
+        self.upsert_vertex(src_vertex)
+        self.upsert_vertex(dst_vertex)
+        self.append_edge(edge)
+
+    def upsert_graph(self, graph: "MemoryGraph"):
+        """Upsert a graph."""
+        for vertex in graph.vertices():
+            self.upsert_vertex(vertex)
+
+        for edge in graph.edges():
+            self.append_edge(edge)
+
    def has_vertex(self, vid: str) -> bool:
        """Check whether a vertex exists by ID."""
        return vid in self._vs
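A hedged sketch of the new convenience method, linking a document vertex to a chunk vertex with a typed include edge; the IDs are made up, while the vertex_type/edge_type property names mirror how this diff tags graph elements:

from dbgpt.storage.graph_store.graph import GraphElemType, MemoryGraph

g = MemoryGraph()
g.upsert_vertex_and_edge(
    src_vid="doc:report.md",      # hypothetical document ID
    src_name="report.md",
    src_props={"vertex_type": GraphElemType.DOCUMENT.value},
    dst_vid="chunk:report.md:0",  # hypothetical chunk ID
    dst_name="chunk 0",
    dst_props={"vertex_type": GraphElemType.CHUNK.value},
    edge_name="include",
    edge_type=GraphElemType.DOCUMENT_INCLUDE_CHUNK.value,
)
assert g.has_vertex("doc:report.md")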
@@ -335,13 +408,26 @@ class MemoryGraph(Graph):

        return itertools.islice(es, limit) if limit else es

-    def vertices(self) -> Iterator[Vertex]:
+    def vertices(
+        self, filter_fn: Optional[Callable[[Vertex], bool]] = None
+    ) -> Iterator[Vertex]:
        """Return vertices."""
-        return iter(self._vs.values())
+        # Get all vertices in the graph
+        all_vertices = self._vs.values()
+
+        return all_vertices if filter_fn is None else filter(filter_fn, all_vertices)

-    def edges(self) -> Iterator[Edge]:
+    def edges(
+        self, filter_fn: Optional[Callable[[Edge], bool]] = None
+    ) -> Iterator[Edge]:
        """Return edges."""
-        return iter(e for nbs in self._oes.values() for es in nbs.values() for e in es)
+        # Get all edges in the graph
+        all_edges = (e for nbs in self._oes.values() for es in nbs.values() for e in es)
+
+        if filter_fn is None:
+            return all_edges
+        else:
+            return filter(filter_fn, all_edges)

    def del_vertices(self, *vids: str):
        """Delete specified vertices."""
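Continuing the sketch above, the optional filter_fn lets callers walk a single element type; this is exactly how TuGraphStoreAdapter.upsert_graph (later in this diff) splits one MemoryGraph into documents, chunks, entities, and typed edges:

# Only the chunk vertices, and only the chunk ordering edges.
chunks = g.vertices(
    filter_fn=lambda v: v.get_prop("vertex_type") == GraphElemType.CHUNK.value
)
next_edges = g.edges(
    filter_fn=lambda e: e.get_prop("edge_type") == GraphElemType.CHUNK_NEXT_CHUNK.value
)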
@@ -353,7 +439,7 @@ class MemoryGraph(Graph):
        """Delete edges."""
        old_edge_cnt = len(self._oes[sid][tid])

-        def remove_matches(es):
+        def remove_matches(es: Set[Edge]):
            return set(
                filter(
                    lambda e: not (
@@ -452,7 +538,7 @@ class MemoryGraph(Graph):
            ]
        }

-    def format(self) -> str:
+    def format(self, entities_only: Optional[bool] = False) -> str:
        """Format graph to string."""
        vs_str = "\n".join(v.format() for v in self.vertices())
        es_str = "\n".join(
@@ -461,11 +547,14 @@ class MemoryGraph(Graph):
            f"{self.get_vertex(e.tid).format(concise=True)}"
            for e in self.edges()
        )
-        return (
-            f"Entities:\n{vs_str}\n\n" f"Relationships:\n{es_str}"
-            if (vs_str or es_str)
-            else ""
-        )
+        if entities_only:
+            return f"Entities:\n{vs_str}" if vs_str else ""
+        else:
+            return (
+                f"Entities:\n{vs_str}\n\nRelationships:\n{es_str}"
+                if (vs_str or es_str)
+                else ""
+            )

    def truncate(self):
        """Truncate graph."""
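The new entities_only switch renders just the vertex block, for callers that want entity text without relationships; continuing the same sketch:

print(g.format())                    # "Entities:\n...\n\nRelationships:\n..."
print(g.format(entities_only=True))  # "Entities:\n..." only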
@@ -1,11 +1,10 @@
"""Memory graph store."""
-import json
+
import logging
-from typing import Generator, List, Optional, Tuple

from dbgpt._private.pydantic import ConfigDict
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
-from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph
+from dbgpt.storage.graph_store.graph import MemoryGraph

logger = logging.getLogger(__name__)

@@ -28,77 +27,8 @@ class MemoryGraphStore(GraphStoreBase):
        """Get the graph store config."""
        return self._graph_store_config

-    def get_edge_type(self) -> str:
-        """Get the edge type."""
-        raise NotImplementedError("Memory graph store does not have edge type")
-
-    def get_vertex_type(self) -> str:
-        """Get the vertex type."""
-        raise NotImplementedError("Memory graph store does not have vertex type")
-
-    def insert_triplet(self, sub: str, rel: str, obj: str):
-        """Insert a triplet into the graph."""
-        self._graph.append_edge(Edge(sub, obj, rel))
-
-    def insert_graph(self, graph: Graph):
-        """Add graph."""
-        for vertex in graph.vertices():
-            self._graph.upsert_vertex(vertex)
-
-        for edge in graph.edges():
-            self._graph.append_edge(edge)
-
-    def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
-        """Retrieve triplets originating from a subject."""
-        subgraph = self.explore([sub], direct=Direction.OUT, depth=1)
-        return [(e.name, e.tid) for e in subgraph.edges()]
-
-    def delete_triplet(self, sub: str, rel: str, obj: str):
-        """Delete a specific triplet from the graph."""
-        self._graph.del_edges(sub, obj, rel)
-
-    def truncate(self):
-        """Truncate graph."""
-        self._graph.truncate()
-
-    def drop(self):
-        """Drop graph."""
-        self._graph = None
-
-    def get_schema(self, refresh: bool = False) -> str:
-        """Return the graph schema as a JSON string."""
-        return json.dumps(self._graph.schema())
-
-    def get_full_graph(self, limit: Optional[int] = None) -> Graph:
-        """Return self."""
-        if not limit:
-            return self._graph
-
-        subgraph = MemoryGraph()
-        for count, edge in enumerate(self._graph.edges()):
-            if count >= limit:
-                break
-            subgraph.upsert_vertex(self._graph.get_vertex(edge.sid))
-            subgraph.upsert_vertex(self._graph.get_vertex(edge.tid))
-            subgraph.append_edge(edge)
-            count += 1
-        return subgraph
-
-    def explore(
-        self,
-        subs: List[str],
-        direct: Direction = Direction.BOTH,
-        depth: Optional[int] = None,
-        fan: Optional[int] = None,
-        limit: Optional[int] = None,
-    ) -> MemoryGraph:
-        """Explore the graph from given subjects up to a depth."""
-        return self._graph.search(subs, direct, depth, fan, limit)
-
-    def query(self, query: str, **args) -> Graph:
-        """Execute a query on graph."""
-        raise NotImplementedError("Query memory graph not allowed")
-
-    def stream_query(self, query: str) -> Generator[Graph, None, None]:
-        """Execute stream query."""
-        raise NotImplementedError("Stream query memory graph not allowed")
+    def _escape_quotes(self, text: str) -> str:
+        """Escape single and double quotes in a string for queries."""
+        raise NotImplementedError(
+            "_escape_quotes is not implemented by MemoryGraphStore"
+        )
@@ -1,14 +1,15 @@
"""TuGraph store."""
+
import base64
import json
import logging
import os
-from typing import Any, Generator, Iterator, List, Optional, Tuple
+from typing import List

from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.datasource.conn_tugraph import TuGraphConnector
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
-from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph, Vertex
+from dbgpt.storage.graph_store.graph import GraphElemType

logger = logging.getLogger(__name__)

@@ -35,12 +36,28 @@ class TuGraphStoreConfig(GraphStoreConfig):
        description="login password",
    )
    vertex_type: str = Field(
-        default="entity",
-        description="The type of vertex, `entity` by default.",
+        default=GraphElemType.ENTITY.value,
+        description="The type of entity vertex, `entity` by default.",
    )
+    document_type: str = Field(
+        default=GraphElemType.DOCUMENT.value,
+        description="The type of document vertex, `document` by default.",
+    )
+    chunk_type: str = Field(
+        default=GraphElemType.CHUNK.value,
+        description="The type of chunk vertex, `chunk` by default.",
+    )
    edge_type: str = Field(
-        default="relation",
-        description="The type of edge, `relation` by default.",
+        default=GraphElemType.RELATION.value,
+        description="The type of relation edge, `relation` by default.",
    )
+    include_type: str = Field(
+        default=GraphElemType.INCLUDE.value,
+        description="The type of include edge, `include` by default.",
+    )
+    next_type: str = Field(
+        default=GraphElemType.NEXT.value,
+        description="The type of next edge, `next` by default.",
+    )
    plugin_names: List[str] = Field(
        default=["leiden"],
@@ -62,17 +79,24 @@ class TuGraphStore(GraphStoreBase):
        self._port = int(os.getenv("TUGRAPH_PORT", config.port))
        self._username = os.getenv("TUGRAPH_USERNAME", config.username)
        self._password = os.getenv("TUGRAPH_PASSWORD", config.password)
-        self._summary_enabled = (
-            os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
-            or config.summary_enabled
+        self._enable_summary = (
+            os.getenv("ENABLE_GRAPH_COMMUNITY_SUMMARY", "").lower() == "true"
+            or config.enable_summary
        )
+        self._enable_document_graph = (
+            os.getenv("ENABLE_DOCUMENT_GRAPH", "").lower() == "true"
+            or config.enable_document_graph
+        )
+        self._enable_triplet_graph = (
+            os.getenv("ENABLE_TRIPLET_GRAPH", "").lower() == "true"
+            or config.enable_triplet_graph
+        )
        self._plugin_names = (
            os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
            or config.plugin_names
        )

        self._graph_name = config.name
-        self._vertex_type = os.getenv("TUGRAPH_VERTEX_TYPE", config.vertex_type)
-        self._edge_type = os.getenv("TUGRAPH_EDGE_TYPE", config.edge_type)

        self.conn = TuGraphConnector.from_uri_db(
            host=self._host,
@@ -82,34 +106,29 @@
            db_name=config.name,
        )

-        self._create_graph(config.name)
-
-    def get_vertex_type(self) -> str:
-        """Get the vertex type."""
-        return self._vertex_type
-
-    def get_edge_type(self) -> str:
-        """Get the edge type."""
-        return self._edge_type
-
-    def _create_graph(self, graph_name: str):
-        self.conn.create_graph(graph_name=graph_name)
-        self._create_schema()
-        if self._summary_enabled:
-            self._upload_plugin()
-
-    def _check_label(self, elem_type: str):
-        result = self.conn.get_table_names()
-        if elem_type == "vertex":
-            return self._vertex_type in result["vertex_tables"]
-        if elem_type == "edge":
-            return self._edge_type in result["edge_tables"]
+    def get_config(self) -> TuGraphStoreConfig:
+        """Get the TuGraph store config."""
+        return self._config

    def _add_vertex_index(self, field_name):
-        gql = f"CALL db.addIndex('{self._vertex_type}', '{field_name}', false)"
+        """Add an index to the vertex table."""
+        # TODO: Not used in the current implementation.
+        gql = f"CALL db.addIndex('{GraphElemType.ENTITY.value}', '{field_name}', false)"
        self.conn.run(gql)

    def _upload_plugin(self):
        """Upload missing plugins to the TuGraph database.

        This method checks for the presence of required plugins in the database and
        uploads any missing plugins. It performs the following steps:
        1. Lists existing plugins in the database.
        2. Identifies missing plugins by comparing with the required plugin list.
        3. For each missing plugin, reads its binary content, encodes it, and uploads
           it to the database.

        The method uses the 'leiden' plugin as an example, but can be extended for
        other plugins.
        """
        gql = "CALL db.plugin.listPlugin('CPP','v1')"
        result = self.conn.run(gql)
        result_names = [
@@ -122,8 +141,8 @@ class TuGraphStore(GraphStoreBase):
        if len(missing_plugins):
            for name in missing_plugins:
                try:
-                    from dbgpt_tugraph_plugins import (  # type: ignore # noqa
-                        get_plugin_binary_path,
+                    from dbgpt_tugraph_plugins import (
+                        get_plugin_binary_path,  # type:ignore[import-untyped]
                    )
                except ImportError:
                    logger.error(
@@ -136,375 +155,12 @@ class TuGraphStore(GraphStoreBase):
                    content = f.read()
                content = base64.b64encode(content).decode()
                gql = (
-                    f"CALL db.plugin.loadPlugin('CPP', '{name}', '{content}', "
-                    "'SO', '{name} Plugin', false, 'v1')"
+                    f"CALL db.plugin.loadPlugin('CPP', '{name}', '{content}', 'SO', "
+                    f"'{name} Plugin', false, 'v1')"
                )
                self.conn.run(gql)

-    def _create_schema(self):
-        if not self._check_label("vertex"):
-            if self._summary_enabled:
-                create_vertex_gql = (
-                    f"CALL db.createLabel("
-                    f"'vertex', '{self._vertex_type}', "
-                    f"'id', ['id',string,false],"
-                    f"['name',string,false],"
-                    f"['_document_id',string,true],"
-                    f"['_chunk_id',string,true],"
-                    f"['_community_id',string,true],"
-                    f"['description',string,true])"
-                )
-                self.conn.run(create_vertex_gql)
-                self._add_vertex_index("_community_id")
-            else:
-                create_vertex_gql = (
-                    f"CALL db.createLabel("
-                    f"'vertex', '{self._vertex_type}', "
-                    f"'id', ['id',string,false],"
-                    f"['name',string,false])"
-                )
-                self.conn.run(create_vertex_gql)
-
-        if not self._check_label("edge"):
-            create_edge_gql = f"""CALL db.createLabel(
-                'edge', '{self._edge_type}',
-                '[["{self._vertex_type}",
-                "{self._vertex_type}"]]',
-                ["id",STRING,false],
-                ["name",STRING,false])"""
-            if self._summary_enabled:
-                create_edge_gql = f"""CALL db.createLabel(
-                    'edge', '{self._edge_type}',
-                    '[["{self._vertex_type}",
-                    "{self._vertex_type}"]]',
-                    ["id",STRING,false],
-                    ["name",STRING,false],
-                    ["description",STRING,true])"""
-            self.conn.run(create_edge_gql)
-
-    def _format_query_data(self, data, white_prop_list: List[str]):
-        nodes_list = []
-        rels_list: List[Any] = []
-        _white_list = white_prop_list
-        from neo4j import graph
-
-        def get_filtered_properties(properties, white_list):
-            return {
-                key: value
-                for key, value in properties.items()
-                if (not key.startswith("_") and key not in ["id", "name"])
-                or key in white_list
-            }
-
-        def process_node(node: graph.Node):
-            node_id = node._properties.get("id")
-            node_name = node._properties.get("name")
-            node_properties = get_filtered_properties(node._properties, _white_list)
-            nodes_list.append(
-                {"id": node_id, "name": node_name, "properties": node_properties}
-            )
-
-        def process_relationship(rel: graph.Relationship):
-            name = rel._properties.get("name", "")
-            rel_nodes = rel.nodes
-            src_id = rel_nodes[0]._properties.get("id")
-            dst_id = rel_nodes[1]._properties.get("id")
-            for node in rel_nodes:
-                process_node(node)
-            edge_properties = get_filtered_properties(rel._properties, _white_list)
-            if not any(
-                existing_edge.get("name") == name
-                and existing_edge.get("src_id") == src_id
-                and existing_edge.get("dst_id") == dst_id
-                for existing_edge in rels_list
-            ):
-                rels_list.append(
-                    {
-                        "src_id": src_id,
-                        "dst_id": dst_id,
-                        "name": name,
-                        "properties": edge_properties,
-                    }
-                )
-
-        def process_path(path: graph.Path):
-            for rel in path.relationships:
-                process_relationship(rel)
-
-        def process_other(value):
-            if not any(
-                existing_node.get("id") == "json_node" for existing_node in nodes_list
-            ):
-                nodes_list.append(
-                    {
-                        "id": "json_node",
-                        "name": "json_node",
-                        "properties": {"description": value},
-                    }
-                )
-
-        for record in data:
-            for key in record.keys():
-                value = record[key]
-                if isinstance(value, graph.Node):
-                    process_node(value)
-                elif isinstance(value, graph.Relationship):
-                    process_relationship(value)
-                elif isinstance(value, graph.Path):
-                    process_path(value)
-                else:
-                    process_other(value)
-        nodes = [
-            Vertex(node["id"], node["name"], **node["properties"])
-            for node in nodes_list
-        ]
-        rels = [
-            Edge(edge["src_id"], edge["dst_id"], edge["name"], **edge["properties"])
-            for edge in rels_list
-        ]
-        return {"nodes": nodes, "edges": rels}
-
-    def get_config(self):
-        """Get the graph store config."""
-        return self._config
-
-    def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
-        """Get triplets."""
-        query = (
-            f"MATCH (n1:{self._vertex_type})-[r]->(n2:{self._vertex_type}) "
-            f'WHERE n1.id = "{subj}" RETURN r.id as rel, n2.id as obj;'
-        )
-        data = self.conn.run(query)
-        return [(record["rel"], record["obj"]) for record in data]
-
-    def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
-        """Add triplet."""
-
-        def escape_quotes(value: str) -> str:
-            """Escape single and double quotes in a string for queries."""
-            return value.replace("'", "\\'").replace('"', '\\"')
-
-        subj_escaped = escape_quotes(subj)
-        rel_escaped = escape_quotes(rel)
-        obj_escaped = escape_quotes(obj)
-
-        node_query = f"""CALL db.upsertVertex(
-            '{self._vertex_type}',
-            [{{id:'{subj_escaped}',name:'{subj_escaped}'}},
-            {{id:'{obj_escaped}',name:'{obj_escaped}'}}])"""
-        edge_query = f"""CALL db.upsertEdge(
-            '{self._edge_type}',
-            {{type:"{self._vertex_type}",key:"sid"}},
-            {{type:"{self._vertex_type}", key:"tid"}},
-            [{{sid:"{subj_escaped}",
-            tid: "{obj_escaped}",
-            id:"{rel_escaped}",
-            name: "{rel_escaped}"}}])"""
-        self.conn.run(query=node_query)
-        self.conn.run(query=edge_query)
-
-    def insert_graph(self, graph: Graph) -> None:
-        """Add graph."""
-
-        def escape_quotes(value: str) -> str:
-            """Escape single and double quotes in a string for queries."""
-            if value is not None:
-                return value.replace("'", "").replace('"', "")
-
-        nodes: Iterator[Vertex] = graph.vertices()
-        edges: Iterator[Edge] = graph.edges()
-        node_list = []
-        edge_list = []
-
-        def parser(node_list):
-            formatted_nodes = [
-                "{"
-                + ", ".join(
-                    f'{k}: "{v}"' if isinstance(v, str) else f"{k}: {v}"
-                    for k, v in node.items()
-                )
-                + "}"
-                for node in node_list
-            ]
-            return f"""{', '.join(formatted_nodes)}"""
-
-        for node in nodes:
-            node_list.append(
-                {
-                    "id": escape_quotes(node.vid),
-                    "name": escape_quotes(node.name),
-                    "description": escape_quotes(node.get_prop("description")) or "",
-                    "_document_id": "0",
-                    "_chunk_id": "0",
-                    "_community_id": "0",
-                }
-            )
-        node_query = (
-            f"""CALL db.upsertVertex("{self._vertex_type}", [{parser(node_list)}])"""
-        )
-        for edge in edges:
-            edge_list.append(
-                {
-                    "sid": escape_quotes(edge.sid),
-                    "tid": escape_quotes(edge.tid),
-                    "id": escape_quotes(edge.name),
-                    "name": escape_quotes(edge.name),
-                    "description": escape_quotes(edge.get_prop("description")),
-                }
-            )
-
-        edge_query = f"""CALL db.upsertEdge(
-            "{self._edge_type}",
-            {{type:"{self._vertex_type}", key:"sid"}},
-            {{type:"{self._vertex_type}", key:"tid"}},
-            [{parser(edge_list)}])"""
-        self.conn.run(query=node_query)
-        self.conn.run(query=edge_query)
-
-    def truncate(self):
-        """Truncate Graph."""
-        gql = "MATCH (n) DELETE n"
-        self.conn.run(gql)
-
-    def drop(self):
-        """Delete Graph."""
-        self.conn.delete_graph(self._graph_name)
-
-    def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
-        """Delete triplet."""
-        del_query = (
-            f"MATCH (n1:{self._vertex_type} {{id:'{sub}'}})"
-            f"-[r:{self._edge_type} {{id:'{rel}'}}]->"
-            f"(n2:{self._vertex_type} {{id:'{obj}'}}) DELETE n1,n2,r"
-        )
-        self.conn.run(query=del_query)
-
-    def get_schema(self, refresh: bool = False) -> str:
-        """Get the schema of the graph store."""
-        query = "CALL dbms.graph.getGraphSchema()"
-        data = self.conn.run(query=query)
-        schema = data[0]["schema"]
-        return schema
-
-    def get_full_graph(self, limit: Optional[int] = None) -> Graph:
-        """Get full graph."""
-        if not limit:
-            raise Exception("limit must be set")
-        graph_result = self.query(
-            f"MATCH (n)-[r]-(m) RETURN n,r,m LIMIT {limit}",
-            white_list=["_community_id"],
-        )
-        all_graph = MemoryGraph()
-        for vertex in graph_result.vertices():
-            all_graph.upsert_vertex(vertex)
-        for edge in graph_result.edges():
-            all_graph.append_edge(edge)
-        return all_graph
-
-    def explore(
-        self,
-        subs: List[str],
-        direct: Direction = Direction.BOTH,
-        depth: Optional[int] = None,
-        fan: Optional[int] = None,
-        limit: Optional[int] = None,
-    ) -> Graph:
-        """Explore the graph from given subjects up to a depth."""
-        if not subs:
-            return MemoryGraph()
-
-        if fan is not None:
-            raise ValueError("Fan functionality is not supported at this time.")
-        else:
-            depth_string = f"1..{depth}"
-            if depth is None:
-                depth_string = ".."
-
-            limit_string = f"LIMIT {limit}"
-            if limit is None:
-                limit_string = ""
-            if direct.name == "OUT":
-                rel = f"-[r:{self._edge_type}*{depth_string}]->"
-            elif direct.name == "IN":
-                rel = f"<-[r:{self._edge_type}*{depth_string}]-"
-            else:
-                rel = f"-[r:{self._edge_type}*{depth_string}]-"
-            query = (
-                f"MATCH p=(n:{self._vertex_type})"
-                f"{rel}(m:{self._vertex_type}) "
-                f"WHERE n.id IN {subs} RETURN p {limit_string}"
-            )
-            return self.query(query)
-
-    def query(self, query: str, **args) -> MemoryGraph:
-        """Execute a query on graph."""
-        result = self.conn.run(query=query)
-        white_list = args.get("white_list", [])
-        graph = self._format_query_data(result, white_list)
-        mg = MemoryGraph()
-        for vertex in graph["nodes"]:
-            mg.upsert_vertex(vertex)
-        for edge in graph["edges"]:
-            mg.append_edge(edge)
-        return mg
-
-    def stream_query(self, query: str) -> Generator[Graph, None, None]:
-        """Execute a stream query."""
-        from neo4j import graph
-
-        for record in self.conn.run_stream(query):
-            mg = MemoryGraph()
-            for key in record.keys():
-                value = record[key]
-                if isinstance(value, graph.Node):
-                    node_id = value._properties["id"]
-                    description = value._properties["description"]
-                    vertex = Vertex(node_id, name=node_id, description=description)
-                    mg.upsert_vertex(vertex)
-                elif isinstance(value, graph.Relationship):
-                    rel_nodes = value.nodes
-                    prop_id = value._properties["id"]
-                    src_id = rel_nodes[0]._properties["id"]
-                    dst_id = rel_nodes[1]._properties["id"]
-                    description = value._properties["description"]
-                    edge = Edge(src_id, dst_id, name=prop_id, description=description)
-                    mg.append_edge(edge)
-                elif isinstance(value, graph.Path):
-                    nodes = list(record["p"].nodes)
-                    rels = list(record["p"].relationships)
-                    formatted_path = []
-                    for i in range(len(nodes)):
-                        formatted_path.append(
-                            {
-                                "id": nodes[i]._properties["id"],
-                                "description": nodes[i]._properties["description"],
-                            }
-                        )
-                        if i < len(rels):
-                            formatted_path.append(
-                                {
-                                    "id": rels[i]._properties["id"],
-                                    "description": rels[i]._properties["description"],
-                                }
-                            )
-                    for i in range(0, len(formatted_path), 2):
-                        mg.upsert_vertex(
-                            Vertex(
-                                formatted_path[i]["id"],
-                                name=formatted_path[i]["id"],
-                                description=formatted_path[i]["description"],
-                            )
-                        )
-                        if i + 2 < len(formatted_path):
-                            mg.append_edge(
-                                Edge(
-                                    formatted_path[i]["id"],
-                                    formatted_path[i + 2]["id"],
-                                    name=formatted_path[i + 1]["id"],
-                                    description=formatted_path[i + 1]["description"],
-                                )
-                            )
-                else:
-                    vertex = Vertex("json_node", name="json_node", description=value)
-                    mg.upsert_vertex(vertex)
-            yield mg
+
+    def _escape_quotes(self, value: str) -> str:
+        """Escape single and double quotes in a string for queries."""
+        if value is not None:
+            return value.replace("'", "").replace('"', "")
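After this hunk, TuGraphStore keeps only connection handling, config access, and quote escaping, while the graph operations move to TuGraphStoreAdapter (added below). A hedged sketch of the new call path, assuming a locally reachable TuGraph with default credentials:

import os
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore, TuGraphStoreConfig
from dbgpt.storage.knowledge_graph.community.tugraph_store_adapter import (
    TuGraphStoreAdapter,
)

os.environ["ENABLE_GRAPH_COMMUNITY_SUMMARY"] = "true"  # overrides enable_summary

store = TuGraphStore(TuGraphStoreConfig(name="dbgpt_kg"))
adapter = TuGraphStoreAdapter(store)             # creates the graph via the adapter
adapter.insert_triplet("alice", "knows", "bob")
print(adapter.get_triplets("alice"))             # [("knows", "bob")]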
@@ -1,4 +1,5 @@
"""Knowledge graph base class."""
+
import logging
from abc import ABC, abstractmethod
from typing import List, Optional

@@ -27,6 +28,6 @@ class KnowledgeGraphBase(IndexStoreBase, ABC):
    def query_graph(self, limit: Optional[int] = None) -> Graph:
        """Get graph data."""

+    @abstractmethod
    def delete_by_ids(self, ids: str) -> List[str]:
        """Delete document by ids."""
-        raise Exception("Delete document not supported by knowledge graph")
@@ -1,11 +1,19 @@
"""Define Classes about Community."""
+
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import List, Optional
+from typing import AsyncGenerator, Iterator, List, Optional

from dbgpt.storage.graph_store.base import GraphStoreBase
-from dbgpt.storage.graph_store.graph import Graph
+from dbgpt.storage.graph_store.graph import (
+    Direction,
+    Edge,
+    Graph,
+    GraphElemType,
+    MemoryGraph,
+    Vertex,
+)

logger = logging.getLogger(__name__)

@@ -24,7 +32,7 @@ class CommunityTree:
    """Represents a community tree."""


-class CommunityStoreAdapter(ABC):
+class GraphStoreAdapter(ABC):
    """Community Store Adapter."""

    def __init__(self, graph_store: GraphStoreBase):

@@ -44,6 +52,113 @@ class CommunityStoreAdapter(ABC):
    async def get_community(self, community_id: str) -> Community:
        """Get community."""

+    @abstractmethod
+    def get_graph_config(self):
+        """Get config."""
+
+    @abstractmethod
+    def get_vertex_type(self) -> str:
+        """Get vertex type."""
+
+    @abstractmethod
+    def get_edge_type(self) -> str:
+        """Get edge type."""
+
+    @abstractmethod
+    def get_triplets(self, sub: str) -> List[tuple[str, str]]:
+        """Get triplets."""
+
+    @abstractmethod
+    def get_document_vertex(self, doc_name: str) -> Vertex:
+        """Get document vertex."""
+
+    @abstractmethod
+    def get_schema(self, refresh: bool = False) -> str:
+        """Get schema."""
+
+    @abstractmethod
+    def get_full_graph(self, limit: Optional[int] = None) -> Graph:
+        """Get full graph."""
+
+    @abstractmethod
+    def upsert_entities(self, entities: Iterator[Vertex]) -> None:
+        """Upsert entity."""
+
+    @abstractmethod
+    def upsert_edge(
+        self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
+    ):
+        """Upsert edge."""
+
+    @abstractmethod
+    def upsert_chunks(self, chunk: Iterator[Vertex]) -> None:
+        """Upsert chunk."""
+
+    @abstractmethod
+    def upsert_documents(self, documents: Iterator[Vertex]) -> None:
+        """Upsert documents."""
+
+    @abstractmethod
+    def upsert_relations(self, relations: Iterator[Edge]) -> None:
+        """Upsert relations."""
+
+    @abstractmethod
+    def insert_triplet(self, sub: str, rel: str, obj: str) -> None:
+        """Insert triplet."""
+
+    @abstractmethod
+    def upsert_graph(self, graph: Graph) -> None:
+        """Insert graph."""
+
+    @abstractmethod
+    def delete_document(self, chunk_id: str) -> None:
+        """Delete document in graph store."""
+
+    @abstractmethod
+    def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
+        """Delete triplet."""
+
+    @abstractmethod
+    def drop(self) -> None:
+        """Drop graph."""
+
+    @abstractmethod
+    def create_graph(self, graph_name: str) -> None:
+        """Create graph."""
+
+    @abstractmethod
+    def create_graph_label(self) -> None:
+        """Create a graph label.
+
+        The graph label is used to identify and distinguish different types of nodes
+        (vertices) and edges in the graph.
+        """
+
+    @abstractmethod
+    def truncate(self) -> None:
+        """Truncate graph."""
+
+    @abstractmethod
+    def check_label(self, graph_elem_type: GraphElemType) -> bool:
+        """Check if the label exists in the graph."""
+
+    @abstractmethod
+    def explore(
+        self,
+        subs: List[str],
+        direct: Direction = Direction.BOTH,
+        depth: Optional[int] = None,
+    ) -> MemoryGraph:
+        """Explore the graph from given subjects up to a depth."""
+
+    @abstractmethod
+    def query(self, query: str, **kwargs) -> MemoryGraph:
+        """Execute a query on graph."""
+
+    @abstractmethod
+    async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
+        """Execute a stream query."""


class CommunityMetastore(ABC):
    """Community metastore class."""

@@ -4,10 +4,7 @@ import logging
from typing import List

from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
-from dbgpt.storage.knowledge_graph.community.base import (
-    Community,
-    CommunityStoreAdapter,
-)
+from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
from dbgpt.storage.knowledge_graph.community.community_metastore import (
    BuiltinCommunityMetastore,
)

@@ -21,23 +18,23 @@ class CommunityStore:

    def __init__(
        self,
-        community_store_adapter: CommunityStoreAdapter,
+        graph_store_adapter: GraphStoreAdapter,
        community_summarizer: CommunitySummarizer,
        vector_store: VectorStoreBase,
    ):
        """Initialize the CommunityStore class."""
-        self._community_store_adapter = community_store_adapter
+        self._graph_store_adapter = graph_store_adapter
        self._community_summarizer = community_summarizer
        self._meta_store = BuiltinCommunityMetastore(vector_store)

    async def build_communities(self):
        """Discover communities."""
-        community_ids = await self._community_store_adapter.discover_communities()
+        community_ids = await self._graph_store_adapter.discover_communities()

        # summarize communities
        communities = []
        for community_id in community_ids:
-            community = await self._community_store_adapter.get_community(community_id)
+            community = await self._graph_store_adapter.get_community(community_id)
            graph = community.data.format()
            if not graph:
                break

@@ -65,7 +62,7 @@ class CommunityStore:
        self._community_summarizer.truncate()

        logger.info("Truncate graph")
-        self._community_store_adapter.graph_store.truncate()
+        self._graph_store_adapter.truncate()

    def drop(self):
        """Drop community store."""

@@ -76,4 +73,4 @@ class CommunityStore:
        self._community_summarizer.drop()

        logger.info("Remove graph")
-        self._community_store_adapter.graph_store.drop()
+        self._graph_store_adapter.drop()

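CommunityStore now receives the renamed GraphStoreAdapter directly instead of reaching through adapter.graph_store. A wiring sketch; the module path and the summarizer/vector-store instances are assumptions, not part of this hunk:

from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore

community_store = CommunityStore(
    graph_store_adapter=adapter,      # e.g. a TuGraphStoreAdapter
    community_summarizer=summarizer,  # placeholder: a CommunitySummarizer
    vector_store=vector_store,        # placeholder: a VectorStoreBase
)
await community_store.build_communities()  # discover, summarize, persist (async context)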
@@ -1,28 +1,29 @@
-"""CommunityStoreAdapter factory."""
+"""GraphStoreAdapter factory."""
+
import logging

from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
-from dbgpt.storage.knowledge_graph.community.base import CommunityStoreAdapter
-from dbgpt.storage.knowledge_graph.community.tugraph_adapter import (
-    TuGraphCommunityStoreAdapter,
+from dbgpt.storage.knowledge_graph.community.base import GraphStoreAdapter
+from dbgpt.storage.knowledge_graph.community.tugraph_store_adapter import (
+    TuGraphStoreAdapter,
)

logger = logging.getLogger(__name__)


-class CommunityStoreAdapterFactory:
+class GraphStoreAdapterFactory:
    """Factory for community store adapter."""

    @staticmethod
-    def create(graph_store: GraphStoreBase) -> CommunityStoreAdapter:
-        """Create a CommunityStoreAdapter instance.
+    def create(graph_store: GraphStoreBase) -> GraphStoreAdapter:
+        """Create a GraphStoreAdapter instance.

        Args:
            - graph_store_type: graph store type Memory, TuGraph, Neo4j
        """
        if isinstance(graph_store, TuGraphStore):
-            return TuGraphCommunityStoreAdapter(graph_store)
+            return TuGraphStoreAdapter(graph_store)
        else:
            raise Exception(
                "create community store adapter for %s failed",
@@ -0,0 +1,186 @@
"""MemGraph Community Store Adapter."""

import json
import logging
from typing import AsyncGenerator, Iterator, List, Optional, Tuple

from dbgpt.storage.graph_store.graph import (
    Direction,
    Edge,
    Graph,
    GraphElemType,
    MemoryGraph,
    Vertex,
)
from dbgpt.storage.graph_store.memgraph_store import (
    MemoryGraphStore,
    MemoryGraphStoreConfig,
)
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter

logger = logging.getLogger(__name__)


class MemGraphStoreAdapter(GraphStoreAdapter):
    """MemGraph Community Store Adapter."""

    MAX_HIERARCHY_LEVEL = 3

    def __init__(self, enable_summary: bool = False):
        """Initialize MemGraph Community Store Adapter."""
        self._graph_store = MemoryGraphStore(MemoryGraphStoreConfig())
        self._enable_summary = enable_summary

        super().__init__(self._graph_store)

        # Create the graph
        self.create_graph(self._graph_store.get_config().name)

    async def discover_communities(self, **kwargs) -> List[str]:
        """Run community discovery with leiden."""
        pass

    async def get_community(self, community_id: str) -> Community:
        """Get community."""
        pass

    def get_graph_config(self):
        """Get the graph store config."""
        return self._graph_store.get_config()

    def get_vertex_type(self) -> str:
        """Get the vertex type."""
        # raise NotImplementedError("Memory graph store does not have vertex type")
        return ""

    def get_edge_type(self) -> str:
        """Get the edge type."""
        # raise NotImplementedError("Memory graph store does not have edge type")
        return ""

    def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
        """Get triplets."""
        subgraph = self.explore([subj], direct=Direction.OUT, depth=1)
        return [(e.name, e.tid) for e in subgraph.edges()]

    def get_document_vertex(self, doc_name: str) -> Vertex:
        """Get the document vertex in the graph."""
        raise NotImplementedError("Memory graph store does not have document vertex")

    def get_schema(self, refresh: bool = False) -> str:
        """Get the schema of the graph store."""
        return json.dumps(self._graph_store._graph.schema())

    def get_full_graph(self, limit: Optional[int] = None) -> Graph:
        """Get full graph."""
        if not limit:
            return self._graph_store._graph

        subgraph = MemoryGraph()
        for count, edge in enumerate(self._graph_store._graph.edges()):
            if count >= limit:
                break
            subgraph.upsert_vertex(self._graph_store._graph.get_vertex(edge.sid))
            subgraph.upsert_vertex(self._graph_store._graph.get_vertex(edge.tid))
            subgraph.append_edge(edge)
            count += 1
        return subgraph

    def upsert_entities(self, entities: Iterator[Vertex]) -> None:
        """Upsert entities."""
        pass

    def upsert_edge(
        self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
    ) -> None:
        """Upsert edges."""
        pass

    def upsert_chunks(self, chunks: Iterator[Vertex]) -> None:
        """Upsert chunks."""
        pass

    def upsert_documents(self, documents: Iterator[Vertex]) -> None:
        """Upsert documents."""
        pass

    def upsert_relations(self, relations: Iterator[Edge]) -> None:
        """Upsert relations."""
        pass

    def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add triplet."""
        self._graph_store._graph.append_edge(Edge(subj, obj, rel))

    def upsert_graph(self, graph: Graph) -> None:
        """Add graph to the graph store.

        Args:
            graph (Graph): The graph to be added.
        """
        for vertex in graph.vertices():
            self._graph_store._graph.upsert_vertex(vertex)

        for edge in graph.edges():
            self._graph_store._graph.append_edge(edge)

    def delete_document(self, chunk_ids: str) -> None:
        """Delete document in the graph."""
        pass

    def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
        """Delete triplet."""
        self._graph_store._graph.del_edges(sub, obj, rel)

    def drop(self):
        """Delete Graph."""
        self._graph_store._graph = None

    def create_graph(self, graph_name: str):
        """Create a graph."""
        pass

    def create_graph_label(
        self,
    ) -> None:
        """Create a graph label.

        The graph label is used to identify and distinguish different types of nodes
        (vertices) and edges in the graph.
        """
        pass

    def truncate(self):
        """Truncate Graph."""
        self._graph_store._graph.truncate()

    def check_label(self, graph_elem_type: GraphElemType) -> bool:
        """Check if the label exists in the graph.

        Args:
            graph_elem_type (GraphElemType): The type of the graph element.

        Returns:
            True if the label exists in the specified graph element type, otherwise
            False.
        """
        pass

    def explore(
        self,
        subs: List[str],
        direct: Direction = Direction.BOTH,
        depth: int | None = None,
        fan: int | None = None,
        limit: int | None = None,
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth."""
        return self._graph_store._graph.search(subs, direct, depth, fan, limit)

    def query(self, query: str, **kwargs) -> MemoryGraph:
        """Execute a query on graph."""
        pass

    async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
        """Execute a stream query."""
        pass
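The in-memory adapter offers the same surface without a TuGraph instance, which is convenient for tests; a minimal sketch (module path assumed from the class name):

from dbgpt.storage.knowledge_graph.community.memgraph_store_adapter import (
    MemGraphStoreAdapter,
)

mem_adapter = MemGraphStoreAdapter()
mem_adapter.insert_triplet("alice", "knows", "bob")
assert mem_adapter.get_triplets("alice") == [("knows", "bob")]
mem_adapter.truncate()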
@@ -1,52 +0,0 @@
"""TuGraph Community Store Adapter."""
import json
import logging
from typing import List

from dbgpt.storage.graph_store.graph import MemoryGraph
from dbgpt.storage.knowledge_graph.community.base import (
    Community,
    CommunityStoreAdapter,
)

logger = logging.getLogger(__name__)


class TuGraphCommunityStoreAdapter(CommunityStoreAdapter):
    """TuGraph Community Store Adapter."""

    MAX_HIERARCHY_LEVEL = 3

    async def discover_communities(self, **kwargs) -> List[str]:
        """Run community discovery with leiden."""
        mg = self._graph_store.query(
            "CALL db.plugin.callPlugin"
            "('CPP','leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
        )
        result = mg.get_vertex("json_node").get_prop("description")
        community_ids = json.loads(result)["community_id_list"]
        logger.info(f"Discovered {len(community_ids)} communities.")
        return community_ids

    async def get_community(self, community_id: str) -> Community:
        """Get community."""
        query = (
            f"MATCH (n:{self._graph_store.get_vertex_type()})"
            f"WHERE n._community_id = '{community_id}' RETURN n"
        )
        edge_query = (
            f"MATCH (n:{self._graph_store.get_vertex_type()})-"
            f"[r:{self._graph_store.get_edge_type()}]-"
            f"(m:{self._graph_store.get_vertex_type()})"
            f"WHERE n._community_id = '{community_id}' RETURN n,r,m"
        )

        all_vertex_graph = self._graph_store.aquery(query)
        all_edge_graph = self._graph_store.aquery(edge_query)
        all_graph = MemoryGraph()
        for vertex in all_vertex_graph.vertices():
            all_graph.upsert_vertex(vertex)
        for edge in all_edge_graph.edges():
            all_graph.append_edge(edge)

        return Community(id=community_id, data=all_graph)
dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py (new file, 808 lines)
@@ -0,0 +1,808 @@
"""TuGraph Community Store Adapter."""

import json
import logging
from typing import (
    Any,
    AsyncGenerator,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

from dbgpt.storage.graph_store.graph import (
    Direction,
    Edge,
    Graph,
    GraphElemType,
    MemoryGraph,
    Vertex,
)
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter

logger = logging.getLogger(__name__)


class TuGraphStoreAdapter(GraphStoreAdapter):
    """TuGraph Community Store Adapter."""

    MAX_QUERY_LIMIT = 1000
    MAX_HIERARCHY_LEVEL = 3

    def __init__(self, graph_store: TuGraphStore):
        """Initialize TuGraph Community Store Adapter."""
        super().__init__(graph_store)

        # Create the graph
        self.create_graph(self.graph_store.get_config().name)

    async def discover_communities(self, **kwargs) -> List[str]:
        """Run community discovery with leiden."""
        mg = self.query(
            "CALL db.plugin.callPlugin('CPP',"
            "'leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
        )
        result = mg.get_vertex("json_node").get_prop("description")
        community_ids = json.loads(result)["community_id_list"]
        logger.info(f"Discovered {len(community_ids)} communities.")
        return community_ids

    async def get_community(self, community_id: str) -> Community:
        """Get community."""
        query = (
            f"MATCH (n:{self.get_vertex_type()}) WHERE n._community_id = "
            f"'{community_id}' RETURN n"
        )
        edge_query = (
            f"MATCH (n:{self.get_vertex_type()})-"
            f"[r:{self.get_edge_type()}]-"
            f"(m:{self.get_vertex_type()})"
            f"WHERE n._community_id = '{community_id}' RETURN n,r,m"
        )

        all_vertex_graph = self.query(query)
        all_edge_graph = self.query(edge_query)
        all_graph = MemoryGraph()
        for vertex in all_vertex_graph.vertices():
            all_graph.upsert_vertex(vertex)
        for edge in all_edge_graph.edges():
            all_graph.append_edge(edge)

        return Community(id=community_id, data=all_graph)

    @property
    def graph_store(self) -> TuGraphStore:
        """Get the graph store."""
        return self._graph_store

    def get_graph_config(self):
        """Get the graph store config."""
        return self.graph_store.get_config()

    def get_vertex_type(self) -> str:
        """Get the vertex type."""
        return GraphElemType.ENTITY.value

    def get_edge_type(self) -> str:
        """Get the edge type."""
        return GraphElemType.RELATION.value

    def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
        """Get triplets."""
        triplet_query = (
            f"MATCH (n1:{GraphElemType.ENTITY.value})-[r]->(n2:"
            f"{GraphElemType.ENTITY.value}) "
            f'WHERE n1.id = "{subj}" RETURN r.id as rel, n2.id as obj;'
        )
        data = self.graph_store.conn.run(triplet_query)
        return [(record["rel"], record["obj"]) for record in data]

    def get_document_vertex(self, doc_name: str) -> Vertex:
        """Get the document vertex in the graph."""
        gql = f"""MATCH (n) WHERE n.id = '{doc_name}' RETURN n"""
        graph = self.query(gql)
        vertex = graph.get_vertex(doc_name)
        return vertex

    def get_schema(self, refresh: bool = False) -> str:
        """Get the schema of the graph store."""
        query = "CALL dbms.graph.getGraphSchema()"
        data = self.graph_store.conn.run(query=query)
        schema = data[0]["schema"]
        return schema

    def get_full_graph(self, limit: Optional[int] = None) -> Graph:
        """Get full graph."""
        if not limit:
            limit = self.MAX_QUERY_LIMIT
        if limit <= 0:
            raise ValueError("Limit must be greater than 0.")
        graph_result = self.query(
            f"MATCH (n)-[r]-(m) RETURN n,r,m LIMIT {limit}",
            white_list=["_community_id"],
        )
        full_graph = MemoryGraph()
        for vertex in graph_result.vertices():
            full_graph.upsert_vertex(vertex)
        for edge in graph_result.edges():
            full_graph.append_edge(edge)
        return full_graph

    def upsert_entities(self, entities: Iterator[Vertex]) -> None:
        """Upsert entities."""
        entity_list = [
            {
                "id": self.graph_store._escape_quotes(entity.vid),
                "name": self.graph_store._escape_quotes(entity.name),
                "description": self.graph_store._escape_quotes(
                    entity.get_prop("description")
                )
                or "",
                "_document_id": "0",
                "_chunk_id": "0",
                "_community_id": "0",
            }
            for entity in entities
        ]
        entity_query = (
            f"CALL db.upsertVertex("
            f'"{GraphElemType.ENTITY.value}", '
            f"[{self._parser(entity_list)}])"
        )
        self.graph_store.conn.run(query=entity_query)

    def upsert_edge(
        self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
    ) -> None:
        """Upsert edges."""
        edge_list = [
            {
                "sid": self.graph_store._escape_quotes(edge.sid),
                "tid": self.graph_store._escape_quotes(edge.tid),
                "id": self.graph_store._escape_quotes(edge.name),
                "name": self.graph_store._escape_quotes(edge.name),
                "description": self.graph_store._escape_quotes(
                    edge.get_prop("description")
                )
                or "",
                "_chunk_id": self.graph_store._escape_quotes(edge.get_prop("_chunk_id"))
                or "",
            }
            for edge in edges
        ]
        relation_query = f"""CALL db.upsertEdge("{edge_type}",
            {{type:"{src_type}", key:"sid"}},
            {{type:"{dst_type}", key:"tid"}},
            [{self._parser(edge_list)}])"""
        self.graph_store.conn.run(query=relation_query)
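upsert_edge expects each edge batch to come with the label triple (edge label, source vertex label, target vertex label); upsert_graph below derives those batches from the edge_type property. A direct-call sketch reusing the adapter from earlier, with made-up IDs:

from dbgpt.storage.graph_store.graph import Edge, GraphElemType

edges = iter([
    Edge(
        "doc:report.md", "chunk:report.md:0", "include",
        edge_type=GraphElemType.DOCUMENT_INCLUDE_CHUNK.value,
    ),
])
adapter.upsert_edge(
    edges,
    edge_type=GraphElemType.INCLUDE.value,  # edge label in TuGraph
    src_type=GraphElemType.DOCUMENT.value,  # source vertex label
    dst_type=GraphElemType.CHUNK.value,     # target vertex label
)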
|
||||
def upsert_chunks(self, chunks: Iterator[Vertex]) -> None:
|
||||
"""Upsert chunks."""
|
||||
chunk_list = [
|
||||
{
|
||||
"id": self.graph_store._escape_quotes(chunk.vid),
|
||||
"name": self.graph_store._escape_quotes(chunk.name),
|
||||
"content": self.graph_store._escape_quotes(chunk.get_prop("content")),
|
||||
}
|
||||
for chunk in chunks
|
||||
]
|
||||
chunk_query = (
|
||||
f"CALL db.upsertVertex("
|
||||
f'"{GraphElemType.CHUNK.value}", '
|
||||
f"[{self._parser(chunk_list)}])"
|
||||
)
|
||||
self.graph_store.conn.run(query=chunk_query)
|
||||
|
||||
    def upsert_documents(self, documents: Iterator[Vertex]) -> None:
        """Upsert documents."""
        document_list = [
            {
                "id": self.graph_store._escape_quotes(document.vid),
                "name": self.graph_store._escape_quotes(document.name),
                "content": self.graph_store._escape_quotes(document.get_prop("content"))
                or "",
            }
            for document in documents
        ]
        document_query = (
            "CALL db.upsertVertex("
            f'"{GraphElemType.DOCUMENT.value}", '
            f"[{self._parser(document_list)}])"
        )
        self.graph_store.conn.run(query=document_query)

    def upsert_relations(self, relations: Iterator[Edge]) -> None:
        """Upsert relations."""
        # TODO: not yet implemented for this adapter
        pass

    def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add triplet."""
        subj_escaped = subj.replace("'", "\\'").replace('"', '\\"')
        rel_escaped = rel.replace("'", "\\'").replace('"', '\\"')
        obj_escaped = obj.replace("'", "\\'").replace('"', '\\"')

        vertex_query = f"""CALL db.upsertVertex(
            '{GraphElemType.ENTITY.value}',
            [{{id:'{subj_escaped}',name:'{subj_escaped}'}},
            {{id:'{obj_escaped}',name:'{obj_escaped}'}}])"""
        edge_query = f"""CALL db.upsertEdge(
            '{GraphElemType.RELATION.value}',
            {{type:"{GraphElemType.ENTITY.value}",key:"sid"}},
            {{type:"{GraphElemType.ENTITY.value}", key:"tid"}},
            [{{sid:"{subj_escaped}",
            tid: "{obj_escaped}",
            id:"{rel_escaped}",
            name: "{rel_escaped}"}}])"""

        self.graph_store.conn.run(query=vertex_query)
        self.graph_store.conn.run(query=edge_query)
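    # Usage sketch: insert_triplet("Phil Jaber", "Created", "Philz Coffee")
    # upserts two entity vertices and one relation edge connecting them.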
    def upsert_graph(self, graph: MemoryGraph) -> None:
        """Add a graph to the graph store.

        Args:
            graph (MemoryGraph): The graph to be added.
        """
        # Get the iterators of all the vertices and the edges from the graph
        documents: Iterator[Vertex] = graph.vertices(
            filter_fn=lambda x: x.get_prop("vertex_type")
            == GraphElemType.DOCUMENT.value
        )
        chunks: Iterator[Vertex] = graph.vertices(
            filter_fn=lambda x: x.get_prop("vertex_type") == GraphElemType.CHUNK.value
        )
        entities: Iterator[Vertex] = graph.vertices(
            filter_fn=lambda x: x.get_prop("vertex_type") == GraphElemType.ENTITY.value
        )
        doc_include_chunk: Iterator[Edge] = graph.edges(
            filter_fn=lambda x: x.get_prop("edge_type")
            == GraphElemType.DOCUMENT_INCLUDE_CHUNK.value
        )
        chunk_include_chunk: Iterator[Edge] = graph.edges(
            filter_fn=lambda x: x.get_prop("edge_type")
            == GraphElemType.CHUNK_INCLUDE_CHUNK.value
        )
        chunk_include_entity: Iterator[Edge] = graph.edges(
            filter_fn=lambda x: x.get_prop("edge_type")
            == GraphElemType.CHUNK_INCLUDE_ENTITY.value
        )
        chunk_next_chunk: Iterator[Edge] = graph.edges(
            filter_fn=lambda x: x.get_prop("edge_type")
            == GraphElemType.CHUNK_NEXT_CHUNK.value
        )
        relation: Iterator[Edge] = graph.edges(
            filter_fn=lambda x: x.get_prop("edge_type") == GraphElemType.RELATION.value
        )

        # Upsert the vertices and the edges to the graph store
        self.upsert_entities(entities)
        self.upsert_chunks(chunks)
        self.upsert_documents(documents)
        self.upsert_edge(
            doc_include_chunk,
            GraphElemType.INCLUDE.value,
            GraphElemType.DOCUMENT.value,
            GraphElemType.CHUNK.value,
        )
        self.upsert_edge(
            chunk_include_chunk,
            GraphElemType.INCLUDE.value,
            GraphElemType.CHUNK.value,
            GraphElemType.CHUNK.value,
        )
        self.upsert_edge(
            chunk_include_entity,
            GraphElemType.INCLUDE.value,
            GraphElemType.CHUNK.value,
            GraphElemType.ENTITY.value,
        )
        self.upsert_edge(
            chunk_next_chunk,
            GraphElemType.NEXT.value,
            GraphElemType.CHUNK.value,
            GraphElemType.CHUNK.value,
        )
        self.upsert_edge(
            relation,
            GraphElemType.RELATION.value,
            GraphElemType.ENTITY.value,
            GraphElemType.ENTITY.value,
        )
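    # The taxonomy handled above: document/chunk/entity vertices, plus
    # document -include-> chunk, chunk -include-> chunk, chunk -include-> entity,
    # chunk -next-> chunk, and entity -relation-> entity edges.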
    def delete_document(self, chunk_ids: str) -> None:
        """Delete a document's chunks and related data from the graph."""
        chunkids_list = [uuid.strip() for uuid in chunk_ids.split(",")]
        del_chunk_gql = (
            f"MATCH(m:{GraphElemType.DOCUMENT.value})-[r]->"
            f"(n:{GraphElemType.CHUNK.value}) WHERE n.id IN {chunkids_list} DELETE n"
        )
        del_relation_gql = (
            f"MATCH(m:{GraphElemType.ENTITY.value})-[r:"
            f"{GraphElemType.RELATION.value}]-(n:{GraphElemType.ENTITY.value}) "
            f"WHERE r._chunk_id IN {chunkids_list} DELETE r"
        )
        # Remove vertices that are left without any edges
        delete_only_vertex = "MATCH (n) WHERE NOT EXISTS((n)-[]-()) DELETE n"
        self.graph_store.conn.run(del_chunk_gql)
        self.graph_store.conn.run(del_relation_gql)
        self.graph_store.conn.run(delete_only_vertex)

    def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
        """Delete triplet."""
        del_query = (
            f"MATCH (n1:{GraphElemType.ENTITY.value} {{id:'{sub}'}})"
            f"-[r:{GraphElemType.RELATION.value} {{id:'{rel}'}}]->"
            f"(n2:{GraphElemType.ENTITY.value} {{id:'{obj}'}}) DELETE n1,n2,r"
        )
        self.graph_store.conn.run(query=del_query)

    def drop(self):
        """Delete Graph."""
        self.graph_store.conn.delete_graph(self.get_graph_config().name)

    def create_graph(self, graph_name: str):
        """Create a graph."""
        self.graph_store.conn.create_graph(graph_name=graph_name)

        # Create the graph schema
        def _format_graph_property_schema(
            name: str,
            type: str = "STRING",
            optional: bool = False,
            index: Optional[bool] = None,
            **kwargs,
        ) -> Dict[str, Union[str, bool]]:
            """Format a property schema for TuGraph.

            Args:
                name: The name of the property.
                type: The type of the property.
                optional: Whether the property is optional.
                index: Whether to create an index on the property.
                kwargs: Additional keyword arguments.

            Returns:
                The formatted property schema.
            """
            prop: Dict[str, Union[str, bool]] = {
                "name": name,
                "type": type,
                "optional": optional,
            }

            if index is not None:
                prop["index"] = index

            # Add any additional keyword arguments to the property dictionary
            prop.update(kwargs)
            return prop

        # Create the graph label for the document vertex
        document_properties: List[Dict[str, Union[str, bool]]] = [
            _format_graph_property_schema("id", "STRING", False),
            _format_graph_property_schema("name", "STRING", False),
            _format_graph_property_schema("_community_id", "STRING", True, True),
        ]
        self.create_graph_label(
            graph_elem_type=GraphElemType.DOCUMENT, graph_properties=document_properties
        )

        # Create the graph label for the chunk vertex
        chunk_properties: List[Dict[str, Union[str, bool]]] = [
            _format_graph_property_schema("id", "STRING", False),
            _format_graph_property_schema("name", "STRING", False),
            _format_graph_property_schema("_community_id", "STRING", True, True),
            _format_graph_property_schema("content", "STRING", True, True),
        ]
        self.create_graph_label(
            graph_elem_type=GraphElemType.CHUNK, graph_properties=chunk_properties
        )

        # Create the graph label for the entity vertex
        vertex_properties: List[Dict[str, Union[str, bool]]] = [
            _format_graph_property_schema("id", "STRING", False),
            _format_graph_property_schema("name", "STRING", False),
            _format_graph_property_schema("_community_id", "STRING", True, True),
            _format_graph_property_schema("description", "STRING", True, True),
        ]
        self.create_graph_label(
            graph_elem_type=GraphElemType.ENTITY, graph_properties=vertex_properties
        )

        # Create the graph label for the relation edge
        edge_properties: List[Dict[str, Union[str, bool]]] = [
            _format_graph_property_schema("id", "STRING", False),
            _format_graph_property_schema("name", "STRING", False),
            _format_graph_property_schema("_chunk_id", "STRING", True, True),
            _format_graph_property_schema("description", "STRING", True, True),
        ]
        self.create_graph_label(
            graph_elem_type=GraphElemType.RELATION, graph_properties=edge_properties
        )

        # Create the graph label for the include edge
        include_properties: List[Dict[str, Union[str, bool]]] = [
            _format_graph_property_schema("id", "STRING", False),
            _format_graph_property_schema("name", "STRING", False),
            _format_graph_property_schema("description", "STRING", True),
        ]
        self.create_graph_label(
            graph_elem_type=GraphElemType.INCLUDE, graph_properties=include_properties
        )

        # Create the graph label for the next edge
        next_properties: List[Dict[str, Union[str, bool]]] = [
            _format_graph_property_schema("id", "STRING", False),
            _format_graph_property_schema("name", "STRING", False),
            _format_graph_property_schema("description", "STRING", True),
        ]
        self.create_graph_label(
            graph_elem_type=GraphElemType.NEXT, graph_properties=next_properties
        )

        if self.graph_store._enable_summary:
            self.graph_store._upload_plugin()

    def create_graph_label(
        self,
        graph_elem_type: GraphElemType,
        graph_properties: List[Dict[str, Union[str, bool]]],
    ) -> None:
        """Create a graph label.

        The graph label is used to identify and distinguish different types of nodes
        (vertices) and edges in the graph.
        """
        if graph_elem_type.is_vertex():  # vertex
            data = json.dumps({
                "label": graph_elem_type.value,
                "type": "VERTEX",
                "primary": "id",
                "properties": graph_properties,
            })
            gql = f"""CALL db.createVertexLabelByJson('{data}')"""

            gql_check_exist = (
                f"""CALL db.getLabelSchema('VERTEX', '{graph_elem_type.value}')"""
            )
        else:  # edge

            def edge_direction(graph_elem_type: GraphElemType) -> List[List[str]]:
                """Define the edge direction.

                `include` edge: document -> chunk, chunk -> entity
                `next` edge: chunk -> chunk
                `relation` edge: entity -> entity
                """
                if graph_elem_type.is_vertex():
                    raise ValueError("The graph element type must be an edge.")
                if graph_elem_type == GraphElemType.INCLUDE:
                    return [
                        [GraphElemType.DOCUMENT.value, GraphElemType.CHUNK.value],
                        [GraphElemType.CHUNK.value, GraphElemType.ENTITY.value],
                        [GraphElemType.CHUNK.value, GraphElemType.CHUNK.value],
                    ]
                elif graph_elem_type == GraphElemType.NEXT:
                    return [[GraphElemType.CHUNK.value, GraphElemType.CHUNK.value]]
                elif graph_elem_type == GraphElemType.RELATION:
                    return [[GraphElemType.ENTITY.value, GraphElemType.ENTITY.value]]
                else:
                    raise ValueError("Invalid graph element type.")

            data = json.dumps({
                "label": graph_elem_type.value,
                "type": "EDGE",
                "constraints": edge_direction(graph_elem_type),
                "properties": graph_properties,
            })
            gql = f"""CALL db.createEdgeLabelByJson('{data}')"""

            gql_check_exist = (
                f"""CALL db.getLabelSchema('EDGE', '{graph_elem_type.value}')"""
            )

        # Create the graph label only if it does not already exist
        try:
            self.graph_store.conn.run(
                gql_check_exist
            )  # if the label does not exist, querying raises an exception
        except Exception:
            self.graph_store.conn.run(gql)  # create the graph label
            return

        logger.info(f"Graph label {graph_elem_type.value} already exists.")
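    # For reference, the JSON built above for a vertex label looks roughly like
    # this (sketch; assumes GraphElemType.CHUNK.value is "chunk"):
    #   {"label": "chunk", "type": "VERTEX", "primary": "id",
    #    "properties": [{"name": "id", "type": "STRING", "optional": false}, ...]}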
    def truncate(self):
        """Truncate Graph."""
        gql = "MATCH (n) DELETE n"
        self.graph_store.conn.run(gql)

    def check_label(self, graph_elem_type: GraphElemType) -> bool:
        """Check whether the label exists in the graph.

        Args:
            graph_elem_type (GraphElemType): The type of the graph element.

        Returns:
            True if the label exists for the specified graph element type,
            otherwise False.
        """
        vertex_tables, edge_tables = self.graph_store.conn.get_table_names()

        if graph_elem_type.is_vertex():
            return graph_elem_type.value in vertex_tables
        else:
            return graph_elem_type.value in edge_tables

    def explore(
        self,
        subs: List[str],
        direct: Direction = Direction.BOTH,
        depth: Optional[int] = None,
        limit: Optional[int] = None,
        search_scope: Optional[
            Literal["knowledge_graph", "document_graph"]
        ] = "knowledge_graph",
    ) -> MemoryGraph:
        """Explore the graph from the given subjects up to a depth."""
        if not subs:
            return MemoryGraph()

        if depth is None or depth < 0 or depth > self.MAX_HIERARCHY_LEVEL:
            # TODO: to be discussed, None or MAX_HIERARCHY_LEVEL
            # depth_string = ".."
            depth = self.MAX_HIERARCHY_LEVEL
        depth_string = f"1..{depth}"

        if limit is None:
            limit_string = ""
        else:
            limit_string = f"LIMIT {limit}"

        if search_scope == "knowledge_graph":
            if direct.name == "OUT":
                rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]->"
            elif direct.name == "IN":
                rel = f"<-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
            else:
                rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
            query = (
                f"MATCH p=(n:{GraphElemType.ENTITY.value})"
                f"{rel}(m:{GraphElemType.ENTITY.value}) "
                f"WHERE n.id IN {subs} RETURN p {limit_string}"
            )
            return self.query(query)
        else:
            graph = MemoryGraph()

            for sub in subs:
                # Match chunks whose content contains the subject
                query = (
                    f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
                    f"[r:{GraphElemType.INCLUDE.value}*{depth_string}]-"
                    f"(m:{GraphElemType.CHUNK.value}) WHERE m.content CONTAINS '{sub}' "
                    f"RETURN p {limit_string}"
                )
                result = self.query(query)
                for vertex in result.vertices():
                    graph.upsert_vertex(vertex)
                for edge in result.edges():
                    graph.append_edge(edge)

            return graph
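    # Minimal usage sketch (hypothetical keywords; assumes `adapter` is a
    # constructed graph store adapter instance):
    #   subgraph = adapter.explore(
    #       subs=["Philz Coffee"], depth=2, limit=10,
    #       search_scope="knowledge_graph",
    #   )
    #   print(subgraph.format())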
    def query(self, query: str, **kwargs) -> MemoryGraph:
        """Execute a query on the graph.

        Keyword Args:
            white_list (List[str]): The white list of properties; properties
                that are not in the white list are filtered out of the result.
        """
        query_result = self.graph_store.conn.run(query=query)
        white_list: List[str] = kwargs.get(
            "white_list",
            [
                "id",
                "name",
                "description",
                "_document_id",
                "_chunk_id",
                "_community_id",
            ],
        )
        vertices, edges = self._get_nodes_edges_from_queried_data(
            query_result, white_list
        )
        mg = MemoryGraph()
        for vertex in vertices:
            mg.upsert_vertex(vertex)
        for edge in edges:
            mg.append_edge(edge)
        return mg

    async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
        """Execute a stream query."""
        from neo4j import graph

        async for record in self.graph_store.conn.run_stream(query):
            mg = MemoryGraph()
            for key in record.keys():
                value = record[key]
                if isinstance(value, graph.Node):
                    node_id = value._properties["id"]
                    description = value._properties["description"]
                    vertex = Vertex(vid=node_id, name=node_id, description=description)
                    mg.upsert_vertex(vertex)
                elif isinstance(value, graph.Relationship):
                    edge_nodes = value.nodes
                    prop_id = value._properties["id"]
                    assert edge_nodes and edge_nodes[0] and edge_nodes[1]
                    src_id = edge_nodes[0]._properties["id"]
                    dst_id = edge_nodes[1]._properties["id"]
                    description = value._properties["description"]
                    edge = Edge(
                        sid=src_id, tid=dst_id, name=prop_id, description=description
                    )
                    mg.append_edge(edge)
                elif isinstance(value, graph.Path):
                    nodes = list(record["p"].nodes)
                    rels = list(record["p"].relationships)
                    # Flatten the path into an alternating [node, rel, node, ...] list
                    formatted_path = []
                    for i in range(len(nodes)):
                        formatted_path.append({
                            "id": nodes[i]._properties["id"],
                            "description": nodes[i]._properties["description"],
                        })
                        if i < len(rels):
                            formatted_path.append({
                                "id": rels[i]._properties["id"],
                                "description": rels[i]._properties["description"],
                            })
                    # Even indices hold nodes, odd indices hold relationships
                    for i in range(0, len(formatted_path), 2):
                        mg.upsert_vertex(
                            Vertex(
                                vid=formatted_path[i]["id"],
                                name=formatted_path[i]["id"],
                                description=formatted_path[i]["description"],
                            )
                        )
                        if i + 2 < len(formatted_path):
                            mg.append_edge(
                                Edge(
                                    sid=formatted_path[i]["id"],
                                    tid=formatted_path[i + 2]["id"],
                                    name=formatted_path[i + 1]["id"],
                                    description=formatted_path[i + 1]["description"],
                                )
                            )
                else:  # wrap scalar/JSON values in a synthetic vertex
                    vertex = Vertex(
                        vid="json_node", name="json_node", description=value
                    )
                    mg.upsert_vertex(vertex)
            yield mg
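    # Consumption sketch (assumes an adapter instance named `adapter`):
    #   async for subgraph in adapter.stream_query("MATCH p=(n)-[r]-(m) RETURN p"):
    #       print(subgraph.format())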
    def _get_nodes_edges_from_queried_data(
        self,
        data: List[Dict[str, Any]],
        white_prop_list: List[str],
    ) -> Tuple[List[Vertex], List[Edge]]:
        """Format the queried data.

        Args:
            data: The data to be formatted.
            white_prop_list: The white list of properties.

        Returns:
            Tuple[List[Vertex], List[Edge]]: The formatted vertices and edges.
        """
        vertex_list: List[Vertex] = []
        edge_list: List[Edge] = []

        # Remove id, src_id, dst_id and name from the white list
        # to avoid duplication in the initialization of the vertex and edge
        _white_list = [
            prop
            for prop in white_prop_list
            if prop not in ["id", "src_id", "dst_id", "name"]
        ]

        from neo4j import graph

        def filter_properties(
            properties: Dict[str, Any], white_list: List[str]
        ) -> Dict[str, Any]:
            """Filter the properties.

            Properties that are not in the white list are removed.
            The expected properties are:
            entity_properties = ["id", "name", "description", "_document_id",
            "_chunk_id", "_community_id"]
            edge_properties = ["id", "name", "description", "_chunk_id"]
            """
            return {
                key: value
                for key, value in properties.items()
                if (not key.startswith("_") and key not in ["id", "name"])
                or key in white_list
            }

        # Parse the data to nodes and relationships
        for record in data:
            for value in record.values():
                if isinstance(value, graph.Node):
                    assert value._properties.get("id")
                    vertex = Vertex(
                        vid=value._properties.get("id", ""),
                        name=value._properties.get("name"),
                        **filter_properties(value._properties, _white_list),
                    )
                    if vertex not in vertex_list:
                        # TODO: Do we really need to check it every time?
                        vertex_list.append(vertex)
                elif isinstance(value, graph.Relationship):
                    for node in value.nodes:  # the number of nodes is 2
                        assert node and node._properties
                        vertex = Vertex(
                            vid=node._properties.get("id", ""),
                            name=node._properties.get("name"),
                            **filter_properties(node._properties, _white_list),
                        )
                        if vertex not in vertex_list:
                            vertex_list.append(vertex)

                    assert value.nodes and value.nodes[0] and value.nodes[1]
                    edge = Edge(
                        sid=value.nodes[0]._properties.get("id", ""),
                        tid=value.nodes[1]._properties.get("id", ""),
                        name=value._properties.get("name", ""),
                        **filter_properties(value._properties, _white_list),
                    )
                    if edge not in edge_list:
                        edge_list.append(edge)
                elif isinstance(value, graph.Path):
                    for rel in value.relationships:
                        for node in rel.nodes:  # the number of nodes is 2
                            assert node and node._properties
                            vertex = Vertex(
                                vid=node._properties.get("id", ""),
                                name=node._properties.get("name"),
                                **filter_properties(node._properties, _white_list),
                            )
                            if vertex not in vertex_list:
                                vertex_list.append(vertex)

                        assert rel.nodes and rel.nodes[0] and rel.nodes[1]
                        edge = Edge(
                            sid=rel.nodes[0]._properties.get("id", ""),
                            tid=rel.nodes[1]._properties.get("id", ""),
                            name=rel._properties.get("name", ""),
                            **filter_properties(rel._properties, _white_list),
                        )
                        if edge not in edge_list:
                            edge_list.append(edge)
                else:  # json_node
                    vertex = Vertex(
                        vid="json_node",
                        name="json_node",
                        **filter_properties({"description": value}, _white_list),
                    )
                    if vertex not in vertex_list:
                        vertex_list.append(vertex)
        return vertex_list, edge_list

    def _parser(self, entity_list: List[Dict[str, Any]]) -> str:
        """Parse entities to string."""
        formatted_nodes = [
            "{"
            + ", ".join(
                f'{k}: "{v}"' if isinstance(v, str) else f"{k}: {v}"
                for k, v in node.items()
            )
            + "}"
            for node in entity_list
        ]
        return f"""{", ".join(formatted_nodes)}"""
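    # For example, _parser([{"id": "c1", "name": "Chunk 1"}]) returns the
    # string: {id: "c1", name: "Chunk 1"}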
@@ -2,14 +2,16 @@

import logging
import os
import uuid
from typing import List, Optional

from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import Chunk
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
from dbgpt.rag.transformer.graph_extractor import GraphExtractor
from dbgpt.storage.graph_store.graph import GraphElemType, MemoryGraph
from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore
from dbgpt.storage.knowledge_graph.community.factory import CommunityStoreAdapterFactory
from dbgpt.storage.knowledge_graph.community.factory import GraphStoreAdapterFactory
from dbgpt.storage.knowledge_graph.knowledge_graph import (
    BuiltinKnowledgeGraph,
    BuiltinKnowledgeGraphConfig,
@@ -27,7 +29,8 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    vector_store_type: str = Field(
        default="Chroma", description="The type of vector store."
        default="Chroma",
        description="The type of vector store.",
    )
    user: Optional[str] = Field(
        default=None,
@@ -36,7 +39,8 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
    password: Optional[str] = Field(
        default=None,
        description=(
            "The password of vector store, if not set, will use the default password."
            "The password of vector store, "
            "if not set, will use the default password."
        ),
    )
    extract_topk: int = Field(
@@ -120,7 +124,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
        cfg.score_threshold = self._community_score_threshold

        self._community_store = CommunityStore(
            CommunityStoreAdapterFactory.create(self._graph_store),
            GraphStoreAdapterFactory.create(self._graph_store),
            CommunitySummarizer(self._llm_client, self._model_name),
            VectorStoreFactory.create(
                self._vector_store_type,
@@ -135,21 +139,165 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):

    async def aload_document(self, chunks: List[Chunk]) -> List[str]:
        """Extract and persist graph."""
        # todo add doc node
        for chunk in chunks:
            # todo add chunk node
            # todo add relation doc-chunk
        data_list = self._parse_chunks(chunks)  # parse the chunks by analyzing the document structure
        graph_of_all = MemoryGraph()

            # extract graphs and save
            graphs = await self._graph_extractor.extract(chunk.content)
            for graph in graphs:
                self._graph_store.insert_graph(graph)
        # Support graph search by the document and the chunks
        if self._graph_store.get_config().enable_document_graph:
            doc_vid = str(uuid.uuid4())
            doc_name = os.path.basename(chunks[0].metadata["source"] or "Text_Node")
            for chunk_index, chunk in enumerate(data_list):
                if chunk["parent_id"] != "document":
                    # chunk -> include -> chunk
                    graph_of_all.upsert_vertex_and_edge(
                        src_vid=chunk["parent_id"],
                        src_name=chunk["parent_title"],
                        src_props={
                            "vertex_type": GraphElemType.CHUNK.value,
                            "content": chunk["content"],
                        },
                        dst_vid=chunk["id"],
                        dst_name=chunk["title"],
                        dst_props={
                            "vertex_type": GraphElemType.CHUNK.value,
                            "content": chunk["content"],
                        },
                        edge_name=GraphElemType.INCLUDE.value,
                        edge_type=GraphElemType.CHUNK_INCLUDE_CHUNK.value,
                    )
                else:
                    # document -> include -> chunk
                    graph_of_all.upsert_vertex_and_edge(
                        src_vid=doc_vid,
                        src_name=doc_name,
                        src_props={
                            "vertex_type": GraphElemType.DOCUMENT.value,
                            "content": "",
                        },
                        dst_vid=chunk["id"],
                        dst_name=chunk["title"],
                        dst_props={
                            "vertex_type": GraphElemType.CHUNK.value,
                            "content": chunk["content"],
                        },
                        edge_name=GraphElemType.INCLUDE.value,
                        edge_type=GraphElemType.DOCUMENT_INCLUDE_CHUNK.value,
                    )

                # chunk -> next -> chunk
                if chunk_index >= 1:
                    graph_of_all.upsert_vertex_and_edge(
                        src_vid=data_list[chunk_index - 1]["id"],
                        src_name=data_list[chunk_index - 1]["title"],
                        src_props={
                            "vertex_type": GraphElemType.CHUNK.value,
                            "content": data_list[chunk_index - 1]["content"],
                        },
                        dst_vid=chunk["id"],
                        dst_name=chunk["title"],
                        dst_props={
                            "vertex_type": GraphElemType.CHUNK.value,
                            "content": chunk["content"],
                        },
                        edge_name=GraphElemType.NEXT.value,
                        edge_type=GraphElemType.CHUNK_NEXT_CHUNK.value,
                    )

        # Support knowledge graph search by the entities and the relationships
        if self._graph_store.get_config().enable_triplet_graph:
            for chunk_index, chunk in enumerate(data_list):
                # TODO: Use asyncio to extract the graphs to accelerate the process
                # (pay attention to the CAP of the graph db)

                graphs: List[MemoryGraph] = await self._graph_extractor.extract(
                    chunk["content"]
                )

                for graph in graphs:
                    graph_of_all.upsert_graph(graph)

                    # chunk -> include -> entity
                    if self._graph_store.get_config().enable_document_graph:
                        for vertex in graph.vertices():
                            graph_of_all.upsert_vertex_and_edge(
                                src_vid=chunk["id"],
                                src_name=chunk["title"],
                                src_props={
                                    "vertex_type": GraphElemType.CHUNK.value,
                                    "content": chunk["content"],
                                },
                                dst_vid=vertex.vid,
                                dst_name=vertex.name,
                                dst_props={
                                    "vertex_type": GraphElemType.ENTITY.value,
                                    "description": vertex.props.get("description", ""),
                                },  # note: the description is only used for the entity
                                edge_name=GraphElemType.INCLUDE.value,
                                edge_type=GraphElemType.CHUNK_INCLUDE_ENTITY.value,
                            )

        self._graph_store_adapter.upsert_graph(graph_of_all)

        # use asyncio.gather
        # tasks = [self._graph_extractor.extract(chunk.content) for chunk in chunks]
        # results = await asyncio.gather(*tasks)
        # for result in results:
        #     self._graph_store_adapter.upsert_graph(result[0])

        # build communities and save
        await self._community_store.build_communities()

        return [chunk.chunk_id for chunk in chunks]
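    # The document graph built by aload_document has the shape (sketch):
    #   document --include--> chunk --include--> chunk --include--> entity,
    #   with chunk --next--> chunk edges preserving the reading order.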
    def _parse_chunks(self, chunks: List[Chunk]):
        """Parse the chunks by analyzing the markdown chunks."""
        # TODO: Needs refactoring.
        data = []
        for chunk_index, chunk in enumerate(chunks):
            parent = None
            directory_keys = list(chunk.metadata.keys())[:-1]
            parent_level = directory_keys[-2] if len(directory_keys) > 1 else None
            current_level = directory_keys[-1] if directory_keys else "Header0"

            chunk_data = {
                "id": chunk.chunk_id,
                "title": chunk.metadata.get(current_level, "none_header_chunk"),
                "directory_keys": directory_keys,
                "level": current_level,
                "content": chunk.content,
                "parent_id": None,
                "parent_title": None,
                "type": "chunk",
                "chunk_index": chunk_index,
            }

            # Find the parent chunk by walking the header hierarchy upwards
            if parent_level:
                for parent_direct in reversed(directory_keys[:-1]):
                    parent_title = chunk.metadata.get(parent_direct, None)
                    for n in range(chunk_index - 1, -1, -1):
                        metadata = chunks[n].metadata
                        keys = list(metadata.keys())[:-1]
                        if (
                            metadata
                            and parent_direct == keys[-1]
                            and parent_title == metadata.get(parent_direct)
                        ):
                            parent = chunks[n]
                            chunk_data["parent_id"] = parent.chunk_id
                            chunk_data["parent_title"] = parent_title
                            break
                        if chunk_index - n > len(directory_keys):
                            break
                    if chunk_data["parent_id"]:
                        break

            if not chunk_data["parent_id"]:
                chunk_data["parent_id"] = "document"
            data.append(chunk_data)
        return data
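    # Sketch of the output for a markdown document with headers "# A" > "## B"
    # split into two chunks (assumes the header levels are stored as metadata keys):
    #   [{"id": "...", "title": "A", "parent_id": "document", ...},
    #    {"id": "...", "title": "B", "parent_id": "<chunk id of A>", ...}]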
    async def asimilar_search_with_scores(
        self,
        text,
@@ -158,7 +306,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve relevant community summaries."""
        # global search: retrieve relevant community summaries
        # Global search: retrieve relevant community summaries
        communities = await self._community_store.search_communities(text)
        summaries = [
            f"Section {i + 1}:\n{community.summary}"
@@ -166,16 +314,53 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
        ]
        context = "\n".join(summaries) if summaries else ""

        # local search: extract keywords and explore subgraph
        keywords = await self._keyword_extractor.extract(text)
        subgraph = self._graph_store.explore(keywords, limit=topk).format()
        logger.info(f"Search subgraph from {len(keywords)} keywords")
        keywords: List[str] = await self._keyword_extractor.extract(text)

        if not summaries and not subgraph:
        # Local search: extract keywords and explore subgraph
        subgraph = MemoryGraph()
        subgraph_for_doc = MemoryGraph()

        enable_triplet_graph = self._graph_store.get_config().enable_triplet_graph
        enable_document_graph = self._graph_store.get_config().enable_document_graph

        if enable_triplet_graph:
            subgraph: MemoryGraph = self._graph_store_adapter.explore(
                subs=keywords, limit=10, search_scope="knowledge_graph"
            )

            if enable_document_graph:
                # note: appending also extends `keywords`, since no copy is made
                keywords_for_document_graph = keywords
                for vertex in subgraph.vertices():
                    keywords_for_document_graph.append(vertex.name)

                subgraph_for_doc = self._graph_store_adapter.explore(
                    subs=keywords_for_document_graph,
                    limit=5,
                    search_scope="document_graph",
                )
        else:
            if enable_document_graph:
                subgraph_for_doc = self._graph_store_adapter.explore(
                    subs=keywords,
                    limit=10,
                    search_scope="document_graph",
                )

        knowledge_graph_str = subgraph.format()
        knowledge_graph_for_doc_str = subgraph_for_doc.format()

        logger.info(f"Search subgraph from the following {len(keywords)} keywords")

        if not (summaries or knowledge_graph_str or knowledge_graph_for_doc_str):
            return []

        # merge search results into context
        content = HYBRID_SEARCH_PT_CN.format(context=context, graph=subgraph)
        content = HYBRID_SEARCH_PT_CN.format(
            context=context,
            knowledge_graph=knowledge_graph_str,
            knowledge_graph_for_doc=knowledge_graph_for_doc_str,
        )
        logger.info(f"Final GraphRAG queried prompt:\n{content}")
        return [Chunk(content=content)]

    def truncate(self) -> List[str]:
@@ -200,174 +385,179 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
        self._graph_extractor.drop()

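# Minimal retrieval sketch (hypothetical query; assumes `kg` is a constructed
# CommunitySummaryKnowledgeGraph instance):
#   chunks = await kg.asimilar_search_with_scores("Who founded Philz Coffee?", topk=5)
#   print(chunks[0].content)  # the hybrid prompt filled with the retrieved context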
HYBRID_SEARCH_PT_CN = (
    "## 角色\n"
    "你非常擅长结合提示词模板提供的[上下文]信息与[知识图谱]信息,"
    "准确恰当地回答用户的问题,并保证不会输出与上下文和知识图谱无关的信息。"
    "\n"
    "## 技能\n"
    "### 技能 1: 上下文理解\n"
    "- 准确地理解[上下文]提供的信息,上下文信息可能被拆分为多个章节。\n"
    "- 上下文的每个章节内容都会以[Section]开始,并按需进行了编号。\n"
    "- 上下文信息提供了与用户问题相关度最高的总结性描述,请合理使用它们。"
    "### 技能 2: 知识图谱理解\n"
    "- 准确地识别[知识图谱]中提供的[Entities:]章节中的实体信息"
    "和[Relationships:]章节中的关系信息,实体和关系信息的一般格式为:\n"
    "```"
    "* 实体信息格式:\n"
    "- (实体名)\n"
    "- (实体名:实体描述)\n"
    "- (实体名:实体属性表)\n"
    "- (文本块ID:文档块内容)\n"
    "- (目录ID:目录名)\n"
    "- (文档ID:文档名称)\n"
    "\n"
    "* 关系信息的格式:\n"
    "- (来源实体名)-[关系名]->(目标实体名)\n"
    "- (来源实体名)-[关系名:关系描述]->(目标实体名)\n"
    "- (来源实体名)-[关系名:关系属性表]->(目标实体名)\n"
    "- (文本块实体)-[包含]->(实体名)\n"
    "- (目录ID)-[包含]->(文本块实体)\n"
    "- (目录ID)-[包含]->(子目录ID)\n"
    "- (文档ID)-[包含]->(文本块实体)\n"
    "- (文档ID)-[包含]->(目录ID)\n"
    "```"
    "- 正确地将关系信息中的实体名/ID与实体信息关联,还原出图结构。"
    "- 将图结构所表达的信息作为用户提问的明细上下文,辅助生成更好的答案。\n"
    "\n"
    "## 约束条件\n"
    "- 不要在答案中描述你的思考过程,直接给出用户问题的答案,不要生成无关信息。\n"
    "- 若[知识图谱]没有提供信息,此时应根据[上下文]提供的信息回答问题。"
    "- 确保以第三人称书写,从客观角度结合[上下文]和[知识图谱]表达的信息回答问题。\n"
    "- 若提供的信息相互矛盾,请解决矛盾并提供一个单一、连贯的描述。\n"
    "- 避免使用停用词和过于常见的词汇。\n"
    "\n"
    "## 参考案例\n"
    "```\n"
    "[上下文]:\n"
    "Section 1:\n"
    "菲尔・贾伯的大儿子叫雅各布・贾伯。\n"
    "Section 2:\n"
    "菲尔・贾伯的小儿子叫比尔・贾伯。\n"
    "[知识图谱]:\n"
    "Entities:\n"
    "(菲尔・贾伯#菲尔兹咖啡创始人)\n"
    "(菲尔兹咖啡#加利福尼亚州伯克利创立的咖啡品牌)\n"
    "(雅各布・贾伯#菲尔・贾伯的儿子)\n"
    "(美国多地#菲尔兹咖啡的扩展地区)\n"
    "\n"
    "Relationships:\n"
    "(菲尔・贾伯#创建#菲尔兹咖啡#1978年在加利福尼亚州伯克利创立)\n"
    "(菲尔兹咖啡#位于#加利福尼亚州伯克利#菲尔兹咖啡的创立地点)\n"
    "(菲尔・贾伯#拥有#雅各布・贾伯#菲尔・贾伯的儿子)\n"
    "(雅各布・贾伯#担任#首席执行官#在2005年成为菲尔兹咖啡的首席执行官)\n"
    "(菲尔兹咖啡#扩展至#美国多地#菲尔兹咖啡的扩展范围)\n"
    "```\n"
    "\n"
    "----\n"
    "\n"
    "接下来的[上下文]和[知识图谱]的信息,可以帮助你回答更好地用户的问题。\n"
    "\n"
    "[上下文]:\n"
    "{context}\n"
    "\n"
    "[知识图谱]:\n"
    "{graph}\n"
    "\n"
)
HYBRID_SEARCH_PT_CN = """## 角色
|
||||
你非常擅长结合提示词模板提供的[上下文]信息与[知识图谱]信息,
|
||||
准确恰当地回答用户的问题,并保证不会输出与上下文和知识图谱无关的信息。
|
||||
|
||||
HYBRID_SEARCH_PT_EN = (
    "## Role\n"
    "You excel at combining the information provided in the [Context] with "
    "information from the [KnowledgeGraph] to accurately and appropriately "
    "answer user questions, ensuring that you do not output information "
    "unrelated to the context and knowledge graph.\n"
    "\n"
    "## Skills\n"
    "### Skill 1: Context Understanding\n"
    "- Accurately understand the information provided in the [Context], "
    "which may be divided into several sections.\n"
    "- Each section in the context will start with [Section] "
    "and may be numbered as needed.\n"
    "- The context provides a summary description most relevant to the user's "
    "question, and it should be used wisely."
    "### Skill 2: Knowledge Graph Understanding\n"
    "- Accurately identify entity information in the [Entities:] section and "
    "relationship information in the [Relationships:] section "
    "of the [KnowledgeGraph]. The general format for entity "
    "and relationship information is:\n"
    "```"
    "* Entity Information Format:\n"
    "- (entity_name)\n"
    "- (entity_name: entity_description)\n"
    "- (entity_name: entity_property_map)\n"
    "- (chunk_id: chunk_content)\n"
    "- (catalog_id: catalog_name)\n"
    "- (document_id: document_name)\n"
    "\n"
    "* Relationship Information Format:\n"
    "- (source_entity_name)-[relationship_name]->(target_entity_name)\n"
    "- (source_entity_name)-[relationship_name: relationship_description]->"
    "(target_entity_name)\n"
    "- (source_entity_name)-[relationship_name: relationship_property_map]->"
    "(target_entity_name)\n"
    "- (chunk_id)-[Contains]->(entity_name)\n"
    "- (catalog_id)-[Contains]->(chunk_id)\n"
    "- (catalog_id)-[Contains]->(sub_catalog_id)\n"
    "- (document_id)-[Contains]->(chunk_id)\n"
    "- (document_id)-[Contains]->(catalog_id)\n"
    "```"
    "- Correctly associate entity names/IDs in the relationship information "
    "with entity information to restore the graph structure."
    "- Use the information expressed by the graph structure as detailed "
    "context for the user's query to assist in generating better answers.\n"
    "\n"
    "## Constraints\n"
    "- Don't describe your thought process in the answer, provide the answer "
    "to the user's question directly without generating irrelevant information."
    "- If the [KnowledgeGraph] does not provide information, you should answer "
    "the question based on the information provided in the [Context]."
    "- Ensure to write in the third person, responding to questions from "
    "an objective perspective based on the information combined from the "
    "[Context] and the [KnowledgeGraph].\n"
    "- If the provided information is contradictory, resolve the "
    "contradictions and provide a single, coherent description.\n"
    "- Avoid using stop words and overly common vocabulary.\n"
    "\n"
    "## Reference Example\n"
    "```\n"
    "[Context]:\n"
    "Section 1:\n"
    "Phil Schiller's eldest son is Jacob Schiller.\n"
    "Section 2:\n"
    "Phil Schiller's youngest son is Bill Schiller.\n"
    "[KnowledgeGraph]:\n"
    "Entities:\n"
    "(Phil Jaber#Founder of Philz Coffee)\n"
    "(Philz Coffee#Coffee brand founded in Berkeley, California)\n"
    "(Jacob Jaber#Son of Phil Jaber)\n"
    "(Multiple locations in the USA#Expansion regions of Philz Coffee)\n"
    "\n"
    "Relationships:\n"
    "(Phil Jaber#Created#Philz Coffee"
    "#Founded in Berkeley, California in 1978)\n"
    "(Philz Coffee#Located in#Berkeley, California"
    "#Founding location of Philz Coffee)\n"
    "(Phil Jaber#Has#Jacob Jaber#Son of Phil Jaber)\n"
    "(Jacob Jaber#Serves as#CEO#Became CEO of Philz Coffee in 2005)\n"
    "(Philz Coffee#Expanded to#Multiple locations in the USA"
    "#Expansion regions of Philz Coffee)\n"
    "```\n"
    "\n"
    "----\n"
    "\n"
    "The following information from the [Context] and [KnowledgeGraph] can "
    "help you better answer user questions.\n"
    "\n"
    "[Context]:\n"
    "{context}\n"
    "\n"
    "[KnowledgeGraph]:\n"
    "{graph}\n"
    "\n"
)
## 技能
### 技能 1: 上下文理解
- 准确地理解[上下文]提供的信息,上下文信息可能被拆分为多个章节。
- 上下文的每个章节内容都会以[Section]开始,并按需进行了编号。
- 上下文信息提供了与用户问题相关度最高的总结性描述,请合理使用它们。
### 技能 2: 知识图谱理解
- 准确地识别[知识图谱]中提供的[Entities:]章节中的实体信息和[Relationships:]章节中的关系信息,实体和关系信息的一般格式为:
```
* 实体信息格式:
- (实体名)
- (实体名:实体描述)
- (实体名:实体属性表)
- (文本块ID:文档块内容)
- (目录ID:目录名)
- (文档ID:文档名称)

* 关系信息的格式:
- (来源实体名)-[关系名]->(目标实体名)
- (来源实体名)-[关系名:关系描述]->(目标实体名)
- (来源实体名)-[关系名:关系属性表]->(目标实体名)
- (文本块实体)-[包含]->(实体名)
- (目录ID)-[包含]->(文本块实体)
- (目录ID)-[包含]->(子目录ID)
- (文档ID)-[包含]->(文本块实体)
- (文档ID)-[包含]->(目录ID)
```
- 正确地将关系信息中的实体名/ID与实体信息关联,还原出图结构。
- 将图结构所表达的信息作为用户提问的明细上下文,辅助生成更好的答案。


## 约束条件
- 不要在答案中描述你的思考过程,直接给出用户问题的答案,不要生成无关信息。
- 若[知识图谱]或者[知识库原文]没有提供信息,此时应根据[上下文]提供的信息回答问题。
- 确保以第三人称书写,从客观角度结合[上下文]、[知识图谱]和[知识库原文]表达的信息回答问题。
- 若提供的信息相互矛盾,请解决矛盾并提供一个单一、连贯的描述。
- 避免使用停用词和过于常见的词汇。

## 参考案例
```
[上下文]:
Section 1:
菲尔・贾伯的大儿子叫雅各布・贾伯。
Section 2:
菲尔・贾伯的小儿子叫比尔・贾伯。

[知识图谱]:
Entities:
(菲尔・贾伯#菲尔兹咖啡创始人)
(菲尔兹咖啡#加利福尼亚州伯克利创立的咖啡品牌)
(雅各布・贾伯#菲尔・贾伯的儿子)
(美国多地#菲尔兹咖啡的扩展地区)

Relationships:
(菲尔・贾伯#创建#菲尔兹咖啡#1978年在加利福尼亚州伯克利创立)
(菲尔兹咖啡#位于#加利福尼亚州伯克利#菲尔兹咖啡的创立地点)
(菲尔・贾伯#拥有#雅各布・贾伯#菲尔・贾伯的儿子)
(雅各布・贾伯#担任#首席执行官#在2005年成为菲尔兹咖啡的首席执行官)
(菲尔兹咖啡#扩展至#美国多地#菲尔兹咖啡的扩展范围)

[知识库原文]:
...
```

----

接下来的[上下文]、[知识图谱]和[知识库原文]的信息,可以帮助你更好地回答用户的问题。

[上下文]:
{context}

[知识图谱]:
{knowledge_graph}

[知识库原文]
{knowledge_graph_for_doc}
"""  # noqa: E501

HYBRID_SEARCH_PT_EN = """## Role
|
||||
You excel at combining the information provided in the [Context] with
|
||||
information from the [KnowledgeGraph] to accurately and appropriately
|
||||
answer user questions, ensuring that you do not output information
|
||||
unrelated to the context and knowledge graph.
|
||||
|
||||
## Skills
|
||||
### Skill 1: Context Understanding
|
||||
- Accurately understand the information provided in the [Context],
|
||||
which may be divided into several sections.
|
||||
- Each section in the context will start with [Section]
|
||||
and may be numbered as needed.
|
||||
- The context provides a summary description most relevant to the user's
|
||||
question, and it should be used wisely.
|
||||
### Skill 2: Knowledge Graph Understanding
|
||||
- Accurately identify entity information in the [Entities:] section and
|
||||
relationship information in the [Relationships:] section
|
||||
of the [KnowledgeGraph]. The general format for entity
|
||||
and relationship information is:
|
||||
```
|
||||
* Entity Information Format:
|
||||
- (entity_name)
|
||||
- (entity_name: entity_description)
|
||||
- (entity_name: entity_property_map)
|
||||
- (chunk_id: chunk_content)
|
||||
- (catalog_id: catalog_name)
|
||||
- (document_id: document_name)
|
||||
|
||||
* Relationship Information Format:
|
||||
- (source_entity_name)-[relationship_name]->(target_entity_name)
|
||||
- (source_entity_name)-[relationship_name: relationship_description]->(target_entity_name)
|
||||
- (source_entity_name)-[relationship_name: relationship_property_map]->(target_entity_name)
|
||||
- (chunk_id)-[Contains]->(entity_name)
|
||||
- (catalog_id)-[Contains]->(chunk_id)
|
||||
- (catalog_id)-[Contains]->(sub_catalog_id)
|
||||
- (document_id)-[Contains]->(chunk_id)
|
||||
- (document_id)-[Contains]->(catalog_id)
|
||||
```
|
||||
- Correctly associate entity names/IDs in the relationship information
|
||||
with entity information to restore the graph structure.
|
||||
- Use the information expressed by the graph structure as detailed
|
||||
context for the user's query to assist in generating better answers.
|
||||
|
||||
## Constraints
|
||||
- Don't describe your thought process in the answer, provide the answer
|
||||
to the user's question directly without generating irrelevant information.
|
||||
- If the [KnowledgeGraph] or [Knowledge base original text] does not provide information, you should answer
|
||||
the question based on the information provided in the [Context].
|
||||
- Ensure to write in the third person, responding to questions from
|
||||
an objective perspective based on the information combined from the
|
||||
[Context], the [KnowledgeGraph] and the [Knowledge base original text].
|
||||
- If the provided information is contradictory, resolve the
|
||||
contradictions and provide a single, coherent description.
|
||||
- Avoid using stop words and overly common vocabulary.
|
||||
|
||||
## Reference Example
|
||||
```
|
||||
[Context]:
|
||||
Section 1:
|
||||
Phil Schiller's eldest son is Jacob Schiller.
|
||||
Section 2:
|
||||
Phil Schiller's youngest son is Bill Schiller.
|
||||
|
||||
[KnowledgeGraph]:
|
||||
Entities:
|
||||
(Phil Jaber#Founder of Philz Coffee)
|
||||
(Philz Coffee#Coffee brand founded in Berkeley, California)
|
||||
(Jacob Jaber#Son of Phil Jaber)
|
||||
(Multiple locations in the USA#Expansion regions of Philz Coffee)
|
||||
|
||||
Relationships:
|
||||
(Phil Jaber#Created#Philz Coffee#Founded in Berkeley, California in 1978)
|
||||
(Philz Coffee#Located in#Berkeley, California#Founding location of Philz Coffee)
|
||||
(Phil Jaber#Has#Jacob Jaber#Son of Phil Jaber)
|
||||
(Jacob Jaber#Serves as#CEO#Became CEO of Philz Coffee in 2005)
|
||||
(Philz Coffee#Expanded to#Multiple locations in the USA#Expansion regions of Philz Coffee)
|
||||
|
||||
[Knowledge base original text]
|
||||
...
|
||||
```
|
||||
|
||||
----
|
||||
|
||||
The following information from the [Context], [KnowledgeGraph] and [Knowledge base original text]
|
||||
can help you better answer user questions.
|
||||
|
||||
[Context]:
|
||||
{context}
|
||||
|
||||
[KnowledgeGraph]:
|
||||
{knowledge_graph}
|
||||
|
||||
[Knowledge base original text]
|
||||
{knowledge_graph_for_doc}
|
||||
""" # noqa: E501
|
||||
|
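# Both templates are filled via str.format with the keys context,
# knowledge_graph, and knowledge_graph_for_doc (see asimilar_search_with_scores).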
@@ -1,4 +1,5 @@
"""Knowledge graph class."""

import asyncio
import logging
import os
@@ -12,6 +13,8 @@ from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.factory import GraphStoreFactory
from dbgpt.storage.graph_store.graph import Graph
from dbgpt.storage.knowledge_graph.base import KnowledgeGraphBase, KnowledgeGraphConfig
from dbgpt.storage.knowledge_graph.community.base import GraphStoreAdapter
from dbgpt.storage.knowledge_graph.community.factory import GraphStoreAdapterFactory
from dbgpt.storage.vector_store.filters import MetadataFilters

logger = logging.getLogger(__name__)
@@ -46,9 +49,10 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
        self._model_name = config.model_name
        self._triplet_extractor = TripletExtractor(self._llm_client, self._model_name)
        self._keyword_extractor = KeywordExtractor(self._llm_client, self._model_name)
        self._graph_store = self.__init_graph_store(config)
        self._graph_store: GraphStoreBase = self.__init_graph_store(config)
        self._graph_store_adapter: GraphStoreAdapter = self.__init_graph_store_adapter()

    def __init_graph_store(self, config) -> GraphStoreBase:
    def __init_graph_store(self, config: BuiltinKnowledgeGraphConfig) -> GraphStoreBase:
        def configure(cfg: GraphStoreConfig):
            cfg.name = config.name
            cfg.embedding_fn = config.embedding_fn
@@ -56,6 +60,9 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
        graph_store_type = os.getenv("GRAPH_STORE_TYPE") or config.graph_store_type
        return GraphStoreFactory.create(graph_store_type, configure)

    def __init_graph_store_adapter(self):
        return GraphStoreAdapterFactory.create(self._graph_store)

    def get_config(self) -> BuiltinKnowledgeGraphConfig:
        """Get the knowledge graph config."""
        return self._config
@@ -63,10 +70,10 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
    def load_document(self, chunks: List[Chunk]) -> List[str]:
        """Extract and persist triplets to graph store."""

        async def process_chunk(chunk):
        async def process_chunk(chunk: Chunk):
            triplets = await self._triplet_extractor.extract(chunk.content)
            for triplet in triplets:
                self._graph_store.insert_triplet(*triplet)
                self._graph_store_adapter.insert_triplet(*triplet)
            logger.info(f"load {len(triplets)} triplets from chunk {chunk.chunk_id}")
            return chunk.chunk_id

@@ -89,7 +96,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
        for chunk in chunks:
            triplets = await self._triplet_extractor.extract(chunk.content)
            for triplet in triplets:
                self._graph_store.insert_triplet(*triplet)
                self._graph_store_adapter.insert_triplet(*triplet)
            logger.info(f"load {len(triplets)} triplets from chunk {chunk.chunk_id}")
        return [chunk.chunk_id for chunk in chunks]

@@ -116,7 +123,8 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):

        # extract keywords and explore graph store
        keywords = await self._keyword_extractor.extract(text)
        subgraph = self._graph_store.explore(keywords, limit=topk).format()
        subgraph = self._graph_store_adapter.explore(keywords, limit=topk).format()

        logger.info(f"Search subgraph from {len(keywords)} keywords")

        if not subgraph:
@@ -147,12 +155,12 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):

    def query_graph(self, limit: Optional[int] = None) -> Graph:
        """Query graph."""
        return self._graph_store.get_full_graph(limit)
        return self._graph_store_adapter.get_full_graph(limit)

    def truncate(self) -> List[str]:
        """Truncate knowledge graph."""
        logger.info(f"Truncate graph {self._config.name}")
        self._graph_store.truncate()
        self._graph_store_adapter.truncate()

        logger.info("Truncate keyword extractor")
        self._keyword_extractor.truncate()
@@ -165,10 +173,15 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
    def delete_vector_name(self, index_name: str):
        """Delete vector name."""
        logger.info(f"Drop graph {index_name}")
        self._graph_store.drop()
        self._graph_store_adapter.drop()

        logger.info("Drop keyword extractor")
        self._keyword_extractor.drop()

        logger.info("Drop triplet extractor")
        self._triplet_extractor.drop()

    def delete_by_ids(self, ids: str) -> List[str]:
        """Delete by ids."""
        self._graph_store_adapter.delete_document(chunk_ids=ids)
        return []

@@ -1,4 +1,5 @@
"""The database manager."""

from __future__ import annotations

import logging