✨ feat(GraphRAG): enhance GraphRAG by graph community summary (#1801)
Co-authored-by: Florian <fanzhidongyzby@163.com>
Co-authored-by: KingSkyLi <15566300566@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: yvonneyx <zhuyuxin0627@gmail.com>
@@ -5,7 +5,7 @@ from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.rag.index.base import logger
|
||||
from dbgpt.rag.index.base import IndexStoreConfig, logger
|
||||
from dbgpt.storage.full_text.base import FullTextStoreBase
|
||||
from dbgpt.storage.vector_store.elastic_store import ElasticsearchVectorConfig
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
@@ -35,6 +35,7 @@ class ElasticDocumentStore(FullTextStoreBase):
|
||||
This similarity has the following options:
|
||||
"""
|
||||
super().__init__()
|
||||
self._es_config = es_config
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
self._es_config = es_config
|
||||
@@ -94,6 +95,10 @@ class ElasticDocumentStore(FullTextStoreBase):
|
||||
)
|
||||
self._executor = executor or ThreadPoolExecutor()
|
||||
|
||||
def get_config(self) -> IndexStoreConfig:
|
||||
"""Get the es store config."""
|
||||
return self._es_config
|
||||
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""Load document in elasticsearch.
|
||||
|
||||
|
@@ -2,11 +2,11 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.rag.index.base import IndexStoreBase
|
||||
from dbgpt.storage.full_text.base import FullTextStoreBase
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
|
||||
|
||||
class OpenSearch(IndexStoreBase):
|
||||
class OpenSearch(FullTextStoreBase):
|
||||
"""OpenSearch index store."""
|
||||
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
|
@@ -1,7 +1,7 @@
|
||||
"""Graph store base class."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
from dbgpt.core import Embeddings
|
||||
@@ -23,15 +23,35 @@ class GraphStoreConfig(BaseModel):
|
||||
default=None,
|
||||
description="The embedding function of graph store, optional.",
|
||||
)
|
||||
summary_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Enable graph community summary or not.",
|
||||
)
|
||||
|
||||
|
||||
class GraphStoreBase(ABC):
|
||||
"""Graph store base class."""
|
||||
|
||||
@abstractmethod
|
||||
def get_config(self) -> GraphStoreConfig:
|
||||
"""Get the graph store config."""
|
||||
|
||||
@abstractmethod
|
||||
def get_vertex_type(self) -> str:
|
||||
"""Get the vertex type."""
|
||||
|
||||
@abstractmethod
|
||||
def get_edge_type(self) -> str:
|
||||
"""Get the edge type."""
|
||||
|
||||
@abstractmethod
|
||||
def insert_triplet(self, sub: str, rel: str, obj: str):
|
||||
"""Add triplet."""
|
||||
|
||||
@abstractmethod
|
||||
def insert_graph(self, graph: Graph):
|
||||
"""Add graph."""
|
||||
|
||||
@abstractmethod
|
||||
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
|
||||
"""Get triplets."""
|
||||
@@ -40,6 +60,10 @@ class GraphStoreBase(ABC):
|
||||
def delete_triplet(self, sub: str, rel: str, obj: str):
|
||||
"""Delete triplet."""
|
||||
|
||||
@abstractmethod
|
||||
def truncate(self):
|
||||
"""Truncate Graph."""
|
||||
|
||||
@abstractmethod
|
||||
def drop(self):
|
||||
"""Drop graph."""
|
||||
@@ -66,3 +90,11 @@ class GraphStoreBase(ABC):
|
||||
@abstractmethod
|
||||
def query(self, query: str, **args) -> Graph:
|
||||
"""Execute a query."""
|
||||
|
||||
def aquery(self, query: str, **args) -> Graph:
|
||||
"""Async execute a query."""
|
||||
return self.query(query, **args)
|
||||
|
||||
@abstractmethod
|
||||
def stream_query(self, query: str) -> Generator[Graph, None, None]:
|
||||
"""Execute stream query."""
@@ -1,4 +1,4 @@
|
||||
"""Connector for vector store."""
|
||||
"""Graph store factory."""
|
||||
import logging
|
||||
from typing import Tuple, Type
|
||||
|
||||
|
@@ -1,4 +1,4 @@
|
||||
"""Graph store base class."""
|
||||
"""Graph definition."""
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
@@ -24,9 +24,15 @@ class Direction(Enum):
|
||||
class Elem(ABC):
|
||||
"""Elem class."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, name: Optional[str] = None):
|
||||
"""Initialize Elem."""
|
||||
self._props = {}
|
||||
self._name = name
|
||||
self._props: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the edge label."""
|
||||
return self._name or ""
|
||||
|
||||
@property
|
||||
def props(self) -> Dict[str, Any]:
|
||||
@@ -46,14 +52,17 @@ class Elem(ABC):
|
||||
self._props.pop(key, None)
|
||||
|
||||
def has_props(self, **props):
|
||||
"""Check if the element has the specified properties with the given values."""
|
||||
"""Check all key-value pairs exist."""
|
||||
return all(self._props.get(k) == v for k, v in props.items())
|
||||
|
||||
@abstractmethod
|
||||
def format(self, label_key: Optional[str] = None):
|
||||
def format(self) -> str:
|
||||
"""Format properties into a string."""
|
||||
if len(self._props) == 1:
|
||||
return str(next(iter(self._props.values())))
|
||||
|
||||
formatted_props = [
|
||||
f"{k}:{json.dumps(v)}" for k, v in self._props.items() if k != label_key
|
||||
f"{k}:{json.dumps(v, ensure_ascii=False)}" for k, v in self._props.items()
|
||||
]
|
||||
return f"{{{';'.join(formatted_props)}}}"
@@ -61,9 +70,9 @@ class Elem(ABC):
|
||||
class Vertex(Elem):
|
||||
"""Vertex class."""
|
||||
|
||||
def __init__(self, vid: str, **props):
|
||||
def __init__(self, vid: str, name: Optional[str] = None, **props):
|
||||
"""Initialize Vertex."""
|
||||
super().__init__()
|
||||
super().__init__(name)
|
||||
self._vid = vid
|
||||
for k, v in props.items():
|
||||
self.set_prop(k, v)
|
||||
@@ -73,26 +82,43 @@ class Vertex(Elem):
|
||||
"""Return the vertex ID."""
|
||||
return self._vid
|
||||
|
||||
def format(self, label_key: Optional[str] = None):
|
||||
"""Format vertex properties into a string."""
|
||||
label = self.get_prop(label_key) if label_key else self._vid
|
||||
props_str = super().format(label_key)
|
||||
if props_str == "{}":
|
||||
return f"({label})"
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the vertex name."""
|
||||
return super().name or self._vid
|
||||
|
||||
def format(self, concise: bool = False):
|
||||
"""Format vertex into a string."""
|
||||
name = self._name or self._vid
|
||||
if concise:
|
||||
return f"({name})"
|
||||
|
||||
if self._props:
|
||||
return f"({name}:{super().format()})"
|
||||
else:
|
||||
return f"({label}:{props_str})"
|
||||
return f"({name})"
|
||||
|
||||
def __str__(self):
|
||||
"""Return the vertex ID as its string representation."""
|
||||
return f"({self._vid})"
|
||||
|
||||
|
||||
class IdVertex(Vertex):
|
||||
"""IdVertex class."""
|
||||
|
||||
def __init__(self, vid: str):
|
||||
"""Initialize Idvertex."""
|
||||
super().__init__(vid)
|
||||
|
||||
|
||||
class Edge(Elem):
|
||||
"""Edge class."""
|
||||
|
||||
def __init__(self, sid: str, tid: str, **props):
|
||||
def __init__(self, sid: str, tid: str, name: str, **props):
|
||||
"""Initialize Edge."""
|
||||
super().__init__()
|
||||
assert name, "Edge name is required"
|
||||
|
||||
super().__init__(name)
|
||||
self._sid = sid
|
||||
self._tid = tid
|
||||
for k, v in props.items():
|
||||
@@ -117,23 +143,20 @@ class Edge(Elem):
|
||||
else:
|
||||
raise ValueError(f"Get nid of {vid} on {self} failed")
|
||||
|
||||
def format(self, label_key: Optional[str] = None):
|
||||
def format(self):
|
||||
"""Format the edge properties into a string."""
|
||||
label = self.get_prop(label_key) if label_key else ""
|
||||
props_str = super().format(label_key)
|
||||
if props_str == "{}":
|
||||
return f"-[{label}]->" if label else "->"
|
||||
if self._props:
|
||||
return f"-[{self._name}:{super().format()}]->"
|
||||
else:
|
||||
return f"-[{label}:{props_str}]->" if label else f"-[{props_str}]->"
|
||||
return f"-[{self._name}]->"
|
||||
|
||||
def triplet(self, label_key: str) -> Tuple[str, str, str]:
|
||||
def triplet(self) -> Tuple[str, str, str]:
|
||||
"""Return a triplet."""
|
||||
assert label_key, "label key is needed"
|
||||
return self._sid, str(self.get_prop(label_key)), self._tid
|
||||
return self.sid, self.name, self.tid
|
||||
|
||||
def __str__(self):
|
||||
"""Return the edge '(sid)->(tid)'."""
|
||||
return f"({self._sid})->({self._tid})"
|
||||
return f"({self._sid})-[{self._name}]->({self._tid})"
|
||||
|
||||
|
||||
class Graph(ABC):
@@ -177,8 +200,8 @@ class Graph(ABC):
|
||||
"""Delete vertices and their neighbor edges."""
|
||||
|
||||
@abstractmethod
|
||||
def del_edges(self, sid: str, tid: str, **props):
|
||||
"""Delete edges(sid -> tid) matches props."""
|
||||
def del_edges(self, sid: str, tid: str, name: str, **props):
|
||||
"""Delete edges(sid -[name]-> tid) matches props."""
|
||||
|
||||
@abstractmethod
|
||||
def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
|
||||
@@ -203,19 +226,19 @@ class Graph(ABC):
|
||||
def format(self) -> str:
|
||||
"""Format graph data to string."""
|
||||
|
||||
@abstractmethod
|
||||
def truncate(self):
|
||||
"""Truncate graph."""
|
||||
|
||||
|
||||
class MemoryGraph(Graph):
|
||||
"""Graph class."""
|
||||
|
||||
def __init__(self, vertex_label: Optional[str] = None, edge_label: str = "label"):
|
||||
def __init__(self):
|
||||
"""Initialize MemoryGraph with vertex label and edge label."""
|
||||
assert edge_label, "Edge label is needed"
|
||||
|
||||
# metadata
|
||||
self._vertex_label = vertex_label
|
||||
self._edge_label = edge_label
|
||||
self._vertex_prop_keys = {vertex_label} if vertex_label else set()
|
||||
self._edge_prop_keys = {edge_label}
|
||||
self._vertex_prop_keys = set()
|
||||
self._edge_prop_keys = set()
|
||||
self._edge_count = 0
|
||||
|
||||
# init vertices, out edges, in edges index
|
||||
@@ -223,26 +246,6 @@ class MemoryGraph(Graph):
|
||||
self._oes: Any = defaultdict(lambda: defaultdict(set))
|
||||
self._ies: Any = defaultdict(lambda: defaultdict(set))
|
||||
|
||||
@property
|
||||
def vertex_label(self):
|
||||
"""Return the label for vertices."""
|
||||
return self._vertex_label
|
||||
|
||||
@property
|
||||
def edge_label(self):
|
||||
"""Return the label for edges."""
|
||||
return self._edge_label
|
||||
|
||||
@property
|
||||
def vertex_prop_keys(self):
|
||||
"""Return a set of property keys for vertices."""
|
||||
return self._vertex_prop_keys
|
||||
|
||||
@property
|
||||
def edge_prop_keys(self):
|
||||
"""Return a set of property keys for edges."""
|
||||
return self._edge_prop_keys
|
||||
|
||||
@property
|
||||
def vertex_count(self):
|
||||
"""Return the number of vertices in the graph."""
|
||||
@@ -256,7 +259,10 @@ class MemoryGraph(Graph):
|
||||
def upsert_vertex(self, vertex: Vertex):
|
||||
"""Insert or update a vertex based on its ID."""
|
||||
if vertex.vid in self._vs:
|
||||
self._vs[vertex.vid].props.update(vertex.props)
|
||||
if isinstance(self._vs[vertex.vid], IdVertex):
|
||||
self._vs[vertex.vid] = vertex
|
||||
else:
|
||||
self._vs[vertex.vid].props.update(vertex.props)
|
||||
else:
|
||||
self._vs[vertex.vid] = vertex
|
||||
|
||||
@@ -265,9 +271,6 @@ class MemoryGraph(Graph):
|
||||
|
||||
def append_edge(self, edge: Edge):
|
||||
"""Append an edge if it doesn't exist; requires edge label."""
|
||||
if self.edge_label not in edge.props.keys():
|
||||
raise ValueError(f"Edge prop '{self.edge_label}' is needed")
|
||||
|
||||
sid = edge.sid
|
||||
tid = edge.tid
|
||||
|
||||
@@ -275,8 +278,8 @@ class MemoryGraph(Graph):
|
||||
return False
|
||||
|
||||
# init vertex index
|
||||
self._vs.setdefault(sid, Vertex(sid))
|
||||
self._vs.setdefault(tid, Vertex(tid))
|
||||
self._vs.setdefault(sid, IdVertex(sid))
|
||||
self._vs.setdefault(tid, IdVertex(tid))
|
||||
|
||||
# update edge index
|
||||
self._oes[sid][tid].add(edge)
|
||||
@@ -346,18 +349,19 @@ class MemoryGraph(Graph):
|
||||
self.del_neighbor_edges(vid, Direction.BOTH)
|
||||
self._vs.pop(vid, None)
|
||||
|
||||
def del_edges(self, sid: str, tid: str, **props):
|
||||
def del_edges(self, sid: str, tid: str, name: str, **props):
|
||||
"""Delete edges."""
|
||||
old_edge_cnt = len(self._oes[sid][tid])
|
||||
|
||||
if not props:
|
||||
self._edge_count -= old_edge_cnt
|
||||
self._oes[sid].pop(tid, None)
|
||||
self._ies[tid].pop(sid, None)
|
||||
return
|
||||
|
||||
def remove_matches(es):
|
||||
return set(filter(lambda e: not e.has_props(**props), es))
|
||||
return set(
|
||||
filter(
|
||||
lambda e: not (
|
||||
(name == e.name if name else True) and e.has_props(**props)
|
||||
),
|
||||
es,
|
||||
)
|
||||
)
|
||||
|
||||
self._oes[sid][tid] = remove_matches(self._oes[sid][tid])
|
||||
self._ies[tid][sid] = remove_matches(self._ies[tid][sid])
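del_edges now filters on the edge name in addition to property filters; with no property filter, the whole (sid, tid) bucket is still dropped regardless of name. An illustrative check (not part of the commit):

from dbgpt.storage.graph_store.graph import Edge, MemoryGraph

g = MemoryGraph()
g.append_edge(Edge("a", "b", "knows", since=2019))
g.append_edge(Edge("a", "b", "likes", since=2020))

# With property filters, only edges whose name and props both match are removed.
g.del_edges("a", "b", "likes", since=2020)
assert [e.name for e in g.edges()] == ["knows"]

# With no property filters, every edge between "a" and "b" is dropped.
g.del_edges("a", "b", "knows")
assert not list(g.edges())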
@@ -439,12 +443,10 @@ class MemoryGraph(Graph):
|
||||
"schema": [
|
||||
{
|
||||
"type": "VERTEX",
|
||||
"label": f"{self._vertex_label}",
|
||||
"properties": [{"name": k} for k in self._vertex_prop_keys],
|
||||
},
|
||||
{
|
||||
"type": "EDGE",
|
||||
"label": f"{self._edge_label}",
|
||||
"properties": [{"name": k} for k in self._edge_prop_keys],
|
||||
},
|
||||
]
|
||||
@@ -452,14 +454,30 @@ class MemoryGraph(Graph):
|
||||
|
||||
def format(self) -> str:
|
||||
"""Format graph to string."""
|
||||
vs_str = "\n".join(v.format(self.vertex_label) for v in self.vertices())
|
||||
vs_str = "\n".join(v.format() for v in self.vertices())
|
||||
es_str = "\n".join(
|
||||
f"{self.get_vertex(e.sid).format(self.vertex_label)}"
|
||||
f"{e.format(self.edge_label)}"
|
||||
f"{self.get_vertex(e.tid).format(self.vertex_label)}"
|
||||
f"{self.get_vertex(e.sid).format(concise=True)}"
|
||||
f"{e.format()}"
|
||||
f"{self.get_vertex(e.tid).format(concise=True)}"
|
||||
for e in self.edges()
|
||||
)
|
||||
return f"Vertices:\n{vs_str}\n\nEdges:\n{es_str}"
|
||||
return (
|
||||
f"Entities:\n{vs_str}\n\n" f"Relationships:\n{es_str}"
|
||||
if (vs_str or es_str)
|
||||
else ""
|
||||
)
|
||||
|
||||
def truncate(self):
|
||||
"""Truncate graph."""
|
||||
# clean metadata
|
||||
self._vertex_prop_keys.clear()
|
||||
self._edge_prop_keys.clear()
|
||||
self._edge_count = 0
|
||||
|
||||
# clean data and index
|
||||
self._vs.clear()
|
||||
self._oes.clear()
|
||||
self._ies.clear()
|
||||
|
||||
def graphviz(self, name="g"):
|
||||
"""View graphviz graph: https://dreampuf.github.io/GraphvizOnline."""
|
||||
@@ -468,7 +486,7 @@ class MemoryGraph(Graph):
|
||||
g.add_node(vertex.vid)
|
||||
|
||||
for edge in self.edges():
|
||||
triplet = edge.triplet(self.edge_label)
|
||||
triplet = edge.triplet()
|
||||
g.add_edge(triplet[0], triplet[2], label=triplet[1])
|
||||
|
||||
digraph = nx.nx_agraph.to_agraph(g).to_string()
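MemoryGraph.format() now renders "Entities"/"Relationships", uses the concise vertex form at edge endpoints, and returns an empty string for an empty graph. A rough illustration (not part of the commit; the expected shape follows from the code above):

from dbgpt.storage.graph_store.graph import Edge, MemoryGraph, Vertex

g = MemoryGraph()
g.upsert_vertex(Vertex("tugraph", name="TuGraph", description="graph database"))
g.upsert_vertex(Vertex("dbgpt", name="DB-GPT"))
g.append_edge(Edge("dbgpt", "tugraph", "uses"))

print(g.format())
# Roughly:
# Entities:
# (TuGraph:graph database)
# (DB-GPT)
#
# Relationships:
# (DB-GPT)-[uses]->(TuGraph)

assert MemoryGraph().format() == ""  # an empty graph collapses to an empty string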
@@ -1,9 +1,9 @@
|
||||
"""Graph store base class."""
|
||||
"""Memory graph store."""
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt._private.pydantic import ConfigDict
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
|
||||
from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph
|
||||
|
||||
@@ -15,32 +15,51 @@ class MemoryGraphStoreConfig(GraphStoreConfig):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
edge_name_key: str = Field(
|
||||
default="label",
|
||||
description="The label of edge name, `label` by default.",
|
||||
)
|
||||
|
||||
|
||||
class MemoryGraphStore(GraphStoreBase):
|
||||
"""Memory graph store."""
|
||||
|
||||
def __init__(self, graph_store_config: MemoryGraphStoreConfig):
|
||||
"""Initialize MemoryGraphStore with a memory graph."""
|
||||
self._edge_name_key = graph_store_config.edge_name_key
|
||||
self._graph = MemoryGraph(edge_label=self._edge_name_key)
|
||||
self._graph_store_config = graph_store_config
|
||||
self._graph = MemoryGraph()
|
||||
|
||||
def get_config(self):
|
||||
"""Get the graph store config."""
|
||||
return self._graph_store_config
|
||||
|
||||
def get_edge_type(self) -> str:
|
||||
"""Get the edge type."""
|
||||
raise NotImplementedError("Memory graph store does not have edge type")
|
||||
|
||||
def get_vertex_type(self) -> str:
|
||||
"""Get the vertex type."""
|
||||
raise NotImplementedError("Memory graph store does not have vertex type")
|
||||
|
||||
def insert_triplet(self, sub: str, rel: str, obj: str):
|
||||
"""Insert a triplet into the graph."""
|
||||
self._graph.append_edge(Edge(sub, obj, **{self._edge_name_key: rel}))
|
||||
self._graph.append_edge(Edge(sub, obj, rel))
|
||||
|
||||
def insert_graph(self, graph: Graph):
|
||||
"""Add graph."""
|
||||
for vertex in graph.vertices():
|
||||
self._graph.upsert_vertex(vertex)
|
||||
|
||||
for edge in graph.edges():
|
||||
self._graph.append_edge(edge)
|
||||
|
||||
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
|
||||
"""Retrieve triplets originating from a subject."""
|
||||
subgraph = self.explore([sub], direct=Direction.OUT, depth=1)
|
||||
return [(e.get_prop(self._edge_name_key), e.tid) for e in subgraph.edges()]
|
||||
return [(e.name, e.tid) for e in subgraph.edges()]
|
||||
|
||||
def delete_triplet(self, sub: str, rel: str, obj: str):
|
||||
"""Delete a specific triplet from the graph."""
|
||||
self._graph.del_edges(sub, obj, **{self._edge_name_key: rel})
|
||||
self._graph.del_edges(sub, obj, rel)
|
||||
|
||||
def truncate(self):
|
||||
"""Truncate graph."""
|
||||
self._graph.truncate()
|
||||
|
||||
def drop(self):
|
||||
"""Drop graph."""
|
||||
@@ -50,7 +69,7 @@ class MemoryGraphStore(GraphStoreBase):
|
||||
"""Return the graph schema as a JSON string."""
|
||||
return json.dumps(self._graph.schema())
|
||||
|
||||
def get_full_graph(self, limit: Optional[int] = None) -> MemoryGraph:
|
||||
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
|
||||
"""Return self."""
|
||||
if not limit:
|
||||
return self._graph
|
||||
@@ -79,3 +98,7 @@ class MemoryGraphStore(GraphStoreBase):
|
||||
def query(self, query: str, **args) -> Graph:
|
||||
"""Execute a query on graph."""
|
||||
raise NotImplementedError("Query memory graph not allowed")
|
||||
|
||||
def stream_query(self, query: str) -> Generator[Graph, None, None]:
|
||||
"""Execute stream query."""
|
||||
raise NotImplementedError("Stream query memory graph not allowed")
@@ -1,10 +1,8 @@
|
||||
"""Neo4j vector store."""
|
||||
"""Neo4j store."""
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
|
||||
from dbgpt.storage.graph_store.graph import Direction, Graph, MemoryGraph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -19,46 +17,3 @@ class Neo4jStore(GraphStoreBase):
|
||||
"""Neo4j graph store."""
|
||||
|
||||
# todo: add neo4j implementation
|
||||
|
||||
def __init__(self, graph_store_config: Neo4jStoreConfig):
|
||||
"""Initialize the Neo4jStore with connection details."""
|
||||
pass
|
||||
|
||||
def insert_triplet(self, sub: str, rel: str, obj: str):
|
||||
"""Insert triplets."""
|
||||
pass
|
||||
|
||||
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
|
||||
"""Get triplets."""
|
||||
return []
|
||||
|
||||
def delete_triplet(self, sub: str, rel: str, obj: str):
|
||||
"""Delete triplets."""
|
||||
pass
|
||||
|
||||
def drop(self):
|
||||
"""Drop graph."""
|
||||
pass
|
||||
|
||||
def get_schema(self, refresh: bool = False) -> str:
|
||||
"""Get schema."""
|
||||
return ""
|
||||
|
||||
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
|
||||
"""Get full graph."""
|
||||
return MemoryGraph()
|
||||
|
||||
def explore(
|
||||
self,
|
||||
subs: List[str],
|
||||
direct: Direction = Direction.BOTH,
|
||||
depth: Optional[int] = None,
|
||||
fan: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> Graph:
|
||||
"""Explore the graph from given subjects up to a depth."""
|
||||
return MemoryGraph()
|
||||
|
||||
def query(self, query: str, **args) -> Graph:
|
||||
"""Execute a query on graph."""
|
||||
return MemoryGraph()
|
||||
|
@@ -1,12 +1,14 @@
|
||||
"""TuGraph vector store."""
|
||||
"""TuGraph store."""
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, Generator, Iterator, List, Optional, Tuple
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.datasource.conn_tugraph import TuGraphConnector
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
|
||||
from dbgpt.storage.graph_store.graph import Direction, Edge, MemoryGraph, Vertex
|
||||
from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph, Vertex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,20 +31,24 @@ class TuGraphStoreConfig(GraphStoreConfig):
|
||||
description="login username",
|
||||
)
|
||||
password: str = Field(
|
||||
default="123456",
|
||||
default="73@TuGraph",
|
||||
description="login password",
|
||||
)
|
||||
vertex_type: str = Field(
|
||||
default="entity",
|
||||
description="The type of graph vertex, `entity` by default.",
|
||||
description="The type of vertex, `entity` by default.",
|
||||
)
|
||||
edge_type: str = Field(
|
||||
default="relation",
|
||||
description="The type of graph edge, `relation` by default.",
|
||||
description="The type of edge, `relation` by default.",
|
||||
)
|
||||
edge_name_key: str = Field(
|
||||
default="label",
|
||||
description="The label of edge name, `label` by default.",
|
||||
plugin_names: List[str] = Field(
|
||||
default=["leiden"],
|
||||
description=(
|
||||
"Plugins need to be loaded when initialize TuGraph, "
|
||||
"code: https://github.com/TuGraph-family"
|
||||
"/dbgpt-tugraph-plugins/tree/master/cpp"
|
||||
),
|
||||
)
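The config gains typed vertex/edge labels, an opt-in community-summary switch and the plugin list to preload; the constructor below also honours environment overrides such as TUGRAPH_HOST, TUGRAPH_PORT, GRAPH_COMMUNITY_SUMMARY_ENABLED and TUGRAPH_PLUGIN_NAMES. A hedged configuration sketch (values are placeholders, not part of the commit):

from dbgpt.storage.graph_store.tugraph_store import TuGraphStoreConfig

config = TuGraphStoreConfig(
    name="dbgpt_kg",           # graph (database) name, placeholder
    host="127.0.0.1",
    port=7687,
    username="admin",
    password="73@TuGraph",     # default shipped with this change
    vertex_type="entity",
    edge_type="relation",
    summary_enabled=True,      # turn on graph community summaries
    plugin_names=["leiden"],   # plugins uploaded when the store is created
)
# store = TuGraphStore(config)  # requires a reachable TuGraph instance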
|
||||
|
||||
|
||||
@@ -51,20 +57,23 @@ class TuGraphStore(GraphStoreBase):
|
||||
|
||||
def __init__(self, config: TuGraphStoreConfig) -> None:
|
||||
"""Initialize the TuGraphStore with connection details."""
|
||||
self._host = os.getenv("TUGRAPH_HOST", "127.0.0.1") or config.host
|
||||
self._port = int(os.getenv("TUGRAPH_PORT", 7687)) or config.port
|
||||
self._username = os.getenv("TUGRAPH_USERNAME", "admin") or config.username
|
||||
self._password = os.getenv("TUGRAPH_PASSWORD", "73@TuGraph") or config.password
|
||||
self._node_label = (
|
||||
os.getenv("TUGRAPH_VERTEX_TYPE", "entity") or config.vertex_type
|
||||
self._config = config
|
||||
self._host = os.getenv("TUGRAPH_HOST", config.host)
|
||||
self._port = int(os.getenv("TUGRAPH_PORT", config.port))
|
||||
self._username = os.getenv("TUGRAPH_USERNAME", config.username)
|
||||
self._password = os.getenv("TUGRAPH_PASSWORD", config.password)
|
||||
self._summary_enabled = (
|
||||
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
|
||||
or config.summary_enabled
|
||||
)
|
||||
self._edge_label = (
|
||||
os.getenv("TUGRAPH_EDGE_TYPE", "relation") or config.edge_type
|
||||
)
|
||||
self.edge_name_key = (
|
||||
os.getenv("TUGRAPH_EDGE_NAME_KEY", "label") or config.edge_name_key
|
||||
self._plugin_names = (
|
||||
os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
|
||||
or config.plugin_names
|
||||
)
|
||||
self._graph_name = config.name
|
||||
self._vertex_type = os.getenv("TUGRAPH_VERTEX_TYPE", config.vertex_type)
|
||||
self._edge_type = os.getenv("TUGRAPH_EDGE_TYPE", config.edge_type)
|
||||
|
||||
self.conn = TuGraphConnector.from_uri_db(
|
||||
host=self._host,
|
||||
port=self._port,
|
||||
@@ -72,35 +81,197 @@ class TuGraphStore(GraphStoreBase):
|
||||
pwd=self._password,
|
||||
db_name=config.name,
|
||||
)
|
||||
self.conn.create_graph(graph_name=config.name)
|
||||
|
||||
self._create_graph(config.name)
|
||||
|
||||
def get_vertex_type(self) -> str:
|
||||
"""Get the vertex type."""
|
||||
return self._vertex_type
|
||||
|
||||
def get_edge_type(self) -> str:
|
||||
"""Get the edge type."""
|
||||
return self._edge_type
|
||||
|
||||
def _create_graph(self, graph_name: str):
|
||||
self.conn.create_graph(graph_name=graph_name)
|
||||
self._create_schema()
|
||||
if self._summary_enabled:
|
||||
self._upload_plugin()
|
||||
|
||||
def _check_label(self, elem_type: str):
|
||||
result = self.conn.get_table_names()
|
||||
if elem_type == "vertex":
|
||||
return self._node_label in result["vertex_tables"]
|
||||
return self._vertex_type in result["vertex_tables"]
|
||||
if elem_type == "edge":
|
||||
return self._edge_label in result["edge_tables"]
|
||||
return self._edge_type in result["edge_tables"]
|
||||
|
||||
def _add_vertex_index(self, field_name):
|
||||
gql = f"CALL db.addIndex('{self._vertex_type}', '{field_name}', false)"
|
||||
self.conn.run(gql)
|
||||
|
||||
def _upload_plugin(self):
|
||||
gql = "CALL db.plugin.listPlugin('CPP','v1')"
|
||||
result = self.conn.run(gql)
|
||||
result_names = [
|
||||
json.loads(record["plugin_description"])["name"] for record in result
|
||||
]
|
||||
missing_plugins = [
|
||||
name for name in self._plugin_names if name not in result_names
|
||||
]
|
||||
|
||||
if len(missing_plugins):
|
||||
for name in missing_plugins:
|
||||
try:
|
||||
from dbgpt_tugraph_plugins import ( # type: ignore # noqa
|
||||
get_plugin_binary_path,
|
||||
)
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"dbgpt-tugraph-plugins is not installed, "
|
||||
"pip install dbgpt-tugraph-plugins==0.1.0rc1 -U -i "
|
||||
"https://pypi.org/simple"
|
||||
)
|
||||
plugin_path = get_plugin_binary_path("leiden")
|
||||
with open(plugin_path, "rb") as f:
|
||||
content = f.read()
|
||||
content = base64.b64encode(content).decode()
|
||||
gql = (
|
||||
f"CALL db.plugin.loadPlugin('CPP', '{name}', '{content}', "
|
||||
"'SO', '{name} Plugin', false, 'v1')"
|
||||
)
|
||||
self.conn.run(gql)
|
||||
|
||||
def _create_schema(self):
|
||||
if not self._check_label("vertex"):
|
||||
create_vertex_gql = (
|
||||
f"CALL db.createLabel("
|
||||
f"'vertex', '{self._node_label}', "
|
||||
f"'id', ['id',string,false])"
|
||||
)
|
||||
self.conn.run(create_vertex_gql)
|
||||
if self._summary_enabled:
|
||||
create_vertex_gql = (
|
||||
f"CALL db.createLabel("
|
||||
f"'vertex', '{self._vertex_type}', "
|
||||
f"'id', ['id',string,false],"
|
||||
f"['name',string,false],"
|
||||
f"['_document_id',string,true],"
|
||||
f"['_chunk_id',string,true],"
|
||||
f"['_community_id',string,true],"
|
||||
f"['description',string,true])"
|
||||
)
|
||||
self.conn.run(create_vertex_gql)
|
||||
self._add_vertex_index("_community_id")
|
||||
else:
|
||||
create_vertex_gql = (
|
||||
f"CALL db.createLabel("
|
||||
f"'vertex', '{self._vertex_type}', "
|
||||
f"'id', ['id',string,false],"
|
||||
f"['name',string,false])"
|
||||
)
|
||||
self.conn.run(create_vertex_gql)
|
||||
|
||||
if not self._check_label("edge"):
|
||||
create_edge_gql = f"""CALL db.createLabel(
|
||||
'edge', '{self._edge_label}', '[["{self._node_label}",
|
||||
"{self._node_label}"]]', ["id",STRING,false])"""
|
||||
'edge', '{self._edge_type}',
|
||||
'[["{self._vertex_type}",
|
||||
"{self._vertex_type}"]]',
|
||||
["id",STRING,false],
|
||||
["name",STRING,false])"""
|
||||
if self._summary_enabled:
|
||||
create_edge_gql = f"""CALL db.createLabel(
|
||||
'edge', '{self._edge_type}',
|
||||
'[["{self._vertex_type}",
|
||||
"{self._vertex_type}"]]',
|
||||
["id",STRING,false],
|
||||
["name",STRING,false],
|
||||
["description",STRING,true])"""
|
||||
self.conn.run(create_edge_gql)
|
||||
|
||||
def _format_query_data(self, data, white_prop_list: List[str]):
|
||||
nodes_list = []
|
||||
rels_list: List[Any] = []
|
||||
_white_list = white_prop_list
|
||||
from neo4j import graph
|
||||
|
||||
def get_filtered_properties(properties, white_list):
|
||||
return {
|
||||
key: value
|
||||
for key, value in properties.items()
|
||||
if (not key.startswith("_") and key not in ["id", "name"])
|
||||
or key in white_list
|
||||
}
|
||||
|
||||
def process_node(node: graph.Node):
|
||||
node_id = node._properties.get("id")
|
||||
node_name = node._properties.get("name")
|
||||
node_properties = get_filtered_properties(node._properties, _white_list)
|
||||
nodes_list.append(
|
||||
{"id": node_id, "name": node_name, "properties": node_properties}
|
||||
)
|
||||
|
||||
def process_relationship(rel: graph.Relationship):
|
||||
name = rel._properties.get("name", "")
|
||||
rel_nodes = rel.nodes
|
||||
src_id = rel_nodes[0]._properties.get("id")
|
||||
dst_id = rel_nodes[1]._properties.get("id")
|
||||
for node in rel_nodes:
|
||||
process_node(node)
|
||||
edge_properties = get_filtered_properties(rel._properties, _white_list)
|
||||
if not any(
|
||||
existing_edge.get("name") == name
|
||||
and existing_edge.get("src_id") == src_id
|
||||
and existing_edge.get("dst_id") == dst_id
|
||||
for existing_edge in rels_list
|
||||
):
|
||||
rels_list.append(
|
||||
{
|
||||
"src_id": src_id,
|
||||
"dst_id": dst_id,
|
||||
"name": name,
|
||||
"properties": edge_properties,
|
||||
}
|
||||
)
|
||||
|
||||
def process_path(path: graph.Path):
|
||||
for rel in path.relationships:
|
||||
process_relationship(rel)
|
||||
|
||||
def process_other(value):
|
||||
if not any(
|
||||
existing_node.get("id") == "json_node" for existing_node in nodes_list
|
||||
):
|
||||
nodes_list.append(
|
||||
{
|
||||
"id": "json_node",
|
||||
"name": "json_node",
|
||||
"properties": {"description": value},
|
||||
}
|
||||
)
|
||||
|
||||
for record in data:
|
||||
for key in record.keys():
|
||||
value = record[key]
|
||||
if isinstance(value, graph.Node):
|
||||
process_node(value)
|
||||
elif isinstance(value, graph.Relationship):
|
||||
process_relationship(value)
|
||||
elif isinstance(value, graph.Path):
|
||||
process_path(value)
|
||||
else:
|
||||
process_other(value)
|
||||
nodes = [
|
||||
Vertex(node["id"], node["name"], **node["properties"])
|
||||
for node in nodes_list
|
||||
]
|
||||
rels = [
|
||||
Edge(edge["src_id"], edge["dst_id"], edge["name"], **edge["properties"])
|
||||
for edge in rels_list
|
||||
]
|
||||
return {"nodes": nodes, "edges": rels}
|
||||
|
||||
def get_config(self):
|
||||
"""Get the graph store config."""
|
||||
return self._config
|
||||
|
||||
def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
|
||||
"""Get triplets."""
|
||||
query = (
|
||||
f"MATCH (n1:{self._node_label})-[r]->(n2:{self._node_label}) "
|
||||
f"MATCH (n1:{self._vertex_type})-[r]->(n2:{self._vertex_type}) "
|
||||
f'WHERE n1.id = "{subj}" RETURN r.id as rel, n2.id as obj;'
|
||||
)
|
||||
data = self.conn.run(query)
|
||||
@@ -117,16 +288,83 @@ class TuGraphStore(GraphStoreBase):
|
||||
rel_escaped = escape_quotes(rel)
|
||||
obj_escaped = escape_quotes(obj)
|
||||
|
||||
subj_query = f"MERGE (n1:{self._node_label} {{id:'{subj_escaped}'}})"
|
||||
obj_query = f"MERGE (n1:{self._node_label} {{id:'{obj_escaped}'}})"
|
||||
rel_query = (
|
||||
f"MERGE (n1:{self._node_label} {{id:'{subj_escaped}'}})"
|
||||
f"-[r:{self._edge_label} {{id:'{rel_escaped}'}}]->"
|
||||
f"(n2:{self._node_label} {{id:'{obj_escaped}'}})"
|
||||
node_query = f"""CALL db.upsertVertex(
|
||||
'{self._vertex_type}',
|
||||
[{{id:'{subj_escaped}',name:'{subj_escaped}'}},
|
||||
{{id:'{obj_escaped}',name:'{obj_escaped}'}}])"""
|
||||
edge_query = f"""CALL db.upsertEdge(
|
||||
'{self._edge_type}',
|
||||
{{type:"{self._vertex_type}",key:"sid"}},
|
||||
{{type:"{self._vertex_type}", key:"tid"}},
|
||||
[{{sid:"{subj_escaped}",
|
||||
tid: "{obj_escaped}",
|
||||
id:"{rel_escaped}",
|
||||
name: "{rel_escaped}"}}])"""
|
||||
self.conn.run(query=node_query)
|
||||
self.conn.run(query=edge_query)
|
||||
|
||||
def insert_graph(self, graph: Graph) -> None:
|
||||
"""Add graph."""
|
||||
|
||||
def escape_quotes(value: str) -> str:
|
||||
"""Escape single and double quotes in a string for queries."""
|
||||
if value is not None:
|
||||
return value.replace("'", "").replace('"', "")
|
||||
|
||||
nodes: Iterator[Vertex] = graph.vertices()
|
||||
edges: Iterator[Edge] = graph.edges()
|
||||
node_list = []
|
||||
edge_list = []
|
||||
|
||||
def parser(node_list):
|
||||
formatted_nodes = [
|
||||
"{"
|
||||
+ ", ".join(
|
||||
f'{k}: "{v}"' if isinstance(v, str) else f"{k}: {v}"
|
||||
for k, v in node.items()
|
||||
)
|
||||
+ "}"
|
||||
for node in node_list
|
||||
]
|
||||
return f"""{', '.join(formatted_nodes)}"""
|
||||
|
||||
for node in nodes:
|
||||
node_list.append(
|
||||
{
|
||||
"id": escape_quotes(node.vid),
|
||||
"name": escape_quotes(node.name),
|
||||
"description": escape_quotes(node.get_prop("description")) or "",
|
||||
"_document_id": "0",
|
||||
"_chunk_id": "0",
|
||||
"_community_id": "0",
|
||||
}
|
||||
)
|
||||
node_query = (
|
||||
f"""CALL db.upsertVertex("{self._vertex_type}", [{parser(node_list)}])"""
|
||||
)
|
||||
self.conn.run(query=subj_query)
|
||||
self.conn.run(query=obj_query)
|
||||
self.conn.run(query=rel_query)
|
||||
for edge in edges:
|
||||
edge_list.append(
|
||||
{
|
||||
"sid": escape_quotes(edge.sid),
|
||||
"tid": escape_quotes(edge.tid),
|
||||
"id": escape_quotes(edge.name),
|
||||
"name": escape_quotes(edge.name),
|
||||
"description": escape_quotes(edge.get_prop("description")),
|
||||
}
|
||||
)
|
||||
|
||||
edge_query = f"""CALL db.upsertEdge(
|
||||
"{self._edge_type}",
|
||||
{{type:"{self._vertex_type}", key:"sid"}},
|
||||
{{type:"{self._vertex_type}", key:"tid"}},
|
||||
[{parser(edge_list)}])"""
|
||||
self.conn.run(query=node_query)
|
||||
self.conn.run(query=edge_query)
|
||||
|
||||
def truncate(self):
|
||||
"""Truncate Graph."""
|
||||
gql = "MATCH (n) DELETE n"
|
||||
self.conn.run(gql)
|
||||
|
||||
def drop(self):
|
||||
"""Delete Graph."""
|
||||
@@ -135,9 +373,9 @@ class TuGraphStore(GraphStoreBase):
|
||||
def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
|
||||
"""Delete triplet."""
|
||||
del_query = (
|
||||
f"MATCH (n1:{self._node_label} {{id:'{sub}'}})"
|
||||
f"-[r:{self._edge_label} {{id:'{rel}'}}]->"
|
||||
f"(n2:{self._node_label} {{id:'{obj}'}}) DELETE n1,n2,r"
|
||||
f"MATCH (n1:{self._vertex_type} {{id:'{sub}'}})"
|
||||
f"-[r:{self._edge_type} {{id:'{rel}'}}]->"
|
||||
f"(n2:{self._vertex_type} {{id:'{obj}'}}) DELETE n1,n2,r"
|
||||
)
|
||||
self.conn.run(query=del_query)
|
||||
|
||||
@@ -148,11 +386,20 @@ class TuGraphStore(GraphStoreBase):
|
||||
schema = data[0]["schema"]
|
||||
return schema
|
||||
|
||||
def get_full_graph(self, limit: Optional[int] = None) -> MemoryGraph:
|
||||
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
|
||||
"""Get full graph."""
|
||||
if not limit:
|
||||
raise Exception("limit must be set")
|
||||
return self.query(f"MATCH (n)-[r]-(m) RETURN n,m,r LIMIT {limit}")
|
||||
graph_result = self.query(
|
||||
f"MATCH (n)-[r]-(m) RETURN n,r,m LIMIT {limit}",
|
||||
white_list=["_community_id"],
|
||||
)
|
||||
all_graph = MemoryGraph()
|
||||
for vertex in graph_result.vertices():
|
||||
all_graph.upsert_vertex(vertex)
|
||||
for edge in graph_result.edges():
|
||||
all_graph.append_edge(edge)
|
||||
return all_graph
|
||||
|
||||
def explore(
|
||||
self,
|
||||
@@ -161,8 +408,11 @@ class TuGraphStore(GraphStoreBase):
|
||||
depth: Optional[int] = None,
|
||||
fan: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> MemoryGraph:
|
||||
) -> Graph:
|
||||
"""Explore the graph from given subjects up to a depth."""
|
||||
if not subs:
|
||||
return MemoryGraph()
|
||||
|
||||
if fan is not None:
|
||||
raise ValueError("Fan functionality is not supported at this time.")
|
||||
else:
|
||||
@@ -173,67 +423,88 @@ class TuGraphStore(GraphStoreBase):
|
||||
limit_string = f"LIMIT {limit}"
|
||||
if limit is None:
|
||||
limit_string = ""
|
||||
|
||||
if direct.name == "OUT":
|
||||
rel = f"-[r:{self._edge_type}*{depth_string}]->"
|
||||
elif direct.name == "IN":
|
||||
rel = f"<-[r:{self._edge_type}*{depth_string}]-"
|
||||
else:
|
||||
rel = f"-[r:{self._edge_type}*{depth_string}]-"
|
||||
query = (
|
||||
f"MATCH p=(n:{self._node_label})"
|
||||
f"-[r:{self._edge_label}*{depth_string}]-(m:{self._node_label}) "
|
||||
f"MATCH p=(n:{self._vertex_type})"
|
||||
f"{rel}(m:{self._vertex_type}) "
|
||||
f"WHERE n.id IN {subs} RETURN p {limit_string}"
|
||||
)
|
||||
return self.query(query)
|
||||
|
||||
def query(self, query: str, **args) -> MemoryGraph:
|
||||
"""Execute a query on graph."""
|
||||
|
||||
def _format_paths(paths):
|
||||
formatted_paths = []
|
||||
for path in paths:
|
||||
formatted_path = []
|
||||
nodes = list(path["p"].nodes)
|
||||
rels = list(path["p"].relationships)
|
||||
for i in range(len(nodes)):
|
||||
formatted_path.append(nodes[i]._properties["id"])
|
||||
if i < len(rels):
|
||||
formatted_path.append(rels[i]._properties["id"])
|
||||
formatted_paths.append(formatted_path)
|
||||
return formatted_paths
|
||||
|
||||
def _format_query_data(data):
|
||||
node_ids_set = set()
|
||||
rels_set = set()
|
||||
from neo4j import graph
|
||||
|
||||
for record in data:
|
||||
for key in record.keys():
|
||||
value = record[key]
|
||||
if isinstance(value, graph.Node):
|
||||
node_id = value._properties["id"]
|
||||
node_ids_set.add(node_id)
|
||||
elif isinstance(value, graph.Relationship):
|
||||
rel_nodes = value.nodes
|
||||
prop_id = value._properties["id"]
|
||||
src_id = rel_nodes[0]._properties["id"]
|
||||
dst_id = rel_nodes[1]._properties["id"]
|
||||
rels_set.add((src_id, dst_id, prop_id))
|
||||
elif isinstance(value, graph.Path):
|
||||
formatted_paths = _format_paths(data)
|
||||
for path in formatted_paths:
|
||||
for i in range(0, len(path), 2):
|
||||
node_ids_set.add(path[i])
|
||||
if i + 2 < len(path):
|
||||
rels_set.add((path[i], path[i + 2], path[i + 1]))
|
||||
|
||||
nodes = [Vertex(node_id) for node_id in node_ids_set]
|
||||
rels = [
|
||||
Edge(src_id, dst_id, label=prop_id)
|
||||
for (src_id, dst_id, prop_id) in rels_set
|
||||
]
|
||||
return {"nodes": nodes, "edges": rels}
|
||||
|
||||
result = self.conn.run(query=query)
|
||||
graph = _format_query_data(result)
|
||||
white_list = args.get("white_list", [])
|
||||
graph = self._format_query_data(result, white_list)
|
||||
mg = MemoryGraph()
|
||||
for vertex in graph["nodes"]:
|
||||
mg.upsert_vertex(vertex)
|
||||
for edge in graph["edges"]:
|
||||
mg.append_edge(edge)
|
||||
return mg
|
||||
|
||||
def stream_query(self, query: str) -> Generator[Graph, None, None]:
|
||||
"""Execute a stream query."""
|
||||
from neo4j import graph
|
||||
|
||||
for record in self.conn.run_stream(query):
|
||||
mg = MemoryGraph()
|
||||
for key in record.keys():
|
||||
value = record[key]
|
||||
if isinstance(value, graph.Node):
|
||||
node_id = value._properties["id"]
|
||||
description = value._properties["description"]
|
||||
vertex = Vertex(node_id, name=node_id, description=description)
|
||||
mg.upsert_vertex(vertex)
|
||||
elif isinstance(value, graph.Relationship):
|
||||
rel_nodes = value.nodes
|
||||
prop_id = value._properties["id"]
|
||||
src_id = rel_nodes[0]._properties["id"]
|
||||
dst_id = rel_nodes[1]._properties["id"]
|
||||
description = value._properties["description"]
|
||||
edge = Edge(src_id, dst_id, name=prop_id, description=description)
|
||||
mg.append_edge(edge)
|
||||
elif isinstance(value, graph.Path):
|
||||
nodes = list(record["p"].nodes)
|
||||
rels = list(record["p"].relationships)
|
||||
formatted_path = []
|
||||
for i in range(len(nodes)):
|
||||
formatted_path.append(
|
||||
{
|
||||
"id": nodes[i]._properties["id"],
|
||||
"description": nodes[i]._properties["description"],
|
||||
}
|
||||
)
|
||||
if i < len(rels):
|
||||
formatted_path.append(
|
||||
{
|
||||
"id": rels[i]._properties["id"],
|
||||
"description": rels[i]._properties["description"],
|
||||
}
|
||||
)
|
||||
for i in range(0, len(formatted_path), 2):
|
||||
mg.upsert_vertex(
|
||||
Vertex(
|
||||
formatted_path[i]["id"],
|
||||
name=formatted_path[i]["id"],
|
||||
description=formatted_path[i]["description"],
|
||||
)
|
||||
)
|
||||
if i + 2 < len(formatted_path):
|
||||
mg.append_edge(
|
||||
Edge(
|
||||
formatted_path[i]["id"],
|
||||
formatted_path[i + 2]["id"],
|
||||
name=formatted_path[i + 1]["id"],
|
||||
description=formatted_path[i + 1]["description"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
vertex = Vertex("json_node", name="json_node", description=value)
|
||||
mg.upsert_vertex(vertex)
|
||||
yield mg
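get_full_graph now whitelists _community_id so community membership survives the round trip, and stream_query yields one MemoryGraph per result record. A small sketch (not part of the commit; dump_graph is a hypothetical helper that needs a live TuGraph):

from dbgpt.storage.graph_store.tugraph_store import TuGraphStore


def dump_graph(store: TuGraphStore, limit: int = 1000) -> None:
    """Print the stored graph; requires a reachable TuGraph behind the store."""
    # get_full_graph assembles a Graph from a whitelisted query, so internal
    # properties like _community_id are preserved.
    print(store.get_full_graph(limit=limit).format())

    # stream_query yields one MemoryGraph per result record.
    for partial in store.stream_query("MATCH (n)-[r]->(m) RETURN n, r, m LIMIT 10"):
        print(partial.format())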
@@ -19,6 +19,10 @@ class KnowledgeGraphConfig(IndexStoreConfig):
|
||||
class KnowledgeGraphBase(IndexStoreBase, ABC):
|
||||
"""Knowledge graph base class."""
|
||||
|
||||
@abstractmethod
|
||||
def get_config(self) -> KnowledgeGraphConfig:
|
||||
"""Get the knowledge graph config."""
|
||||
|
||||
@abstractmethod
|
||||
def query_graph(self, limit: Optional[int] = None) -> Graph:
|
||||
"""Get graph data."""
|
||||
|
dbgpt/storage/knowledge_graph/community/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
|
||||
"""Community Module."""
|
dbgpt/storage/knowledge_graph/community/base.py (new file, 73 lines)
@@ -0,0 +1,73 @@
|
||||
"""Define Classes about Community."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase
|
||||
from dbgpt.storage.graph_store.graph import Graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Community:
|
||||
"""Community class."""
|
||||
|
||||
id: str
|
||||
data: Optional[Graph] = None
|
||||
summary: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommunityTree:
|
||||
"""Represents a community tree."""
|
||||
|
||||
|
||||
class CommunityStoreAdapter(ABC):
|
||||
"""Community Store Adapter."""
|
||||
|
||||
def __init__(self, graph_store: GraphStoreBase):
|
||||
"""Initialize Community Store Adapter."""
|
||||
self._graph_store = graph_store
|
||||
|
||||
@property
|
||||
def graph_store(self) -> GraphStoreBase:
|
||||
"""Get graph store."""
|
||||
return self._graph_store
|
||||
|
||||
@abstractmethod
|
||||
async def discover_communities(self, **kwargs) -> List[str]:
|
||||
"""Run community discovery."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_community(self, community_id: str) -> Community:
|
||||
"""Get community."""
|
||||
|
||||
|
||||
class CommunityMetastore(ABC):
|
||||
"""Community metastore class."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, community_id: str) -> Community:
|
||||
"""Get community."""
|
||||
|
||||
@abstractmethod
|
||||
def list(self) -> List[Community]:
|
||||
"""Get all communities."""
|
||||
|
||||
@abstractmethod
|
||||
async def search(self, query: str) -> List[Community]:
|
||||
"""Search communities relevant to query."""
|
||||
|
||||
@abstractmethod
|
||||
async def save(self, communities: List[Community]):
|
||||
"""Save communities."""
|
||||
|
||||
@abstractmethod
|
||||
async def truncate(self):
|
||||
"""Truncate all communities."""
|
||||
|
||||
@abstractmethod
|
||||
def drop(self):
|
||||
"""Drop community metastore."""
|
dbgpt/storage/knowledge_graph/community/community_metastore.py (new file, 63 lines)
@@ -0,0 +1,63 @@
|
||||
"""Builtin Community metastore."""
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.storage.knowledge_graph.community.base import Community, CommunityMetastore
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BuiltinCommunityMetastore(CommunityMetastore):
|
||||
"""Builtin Community metastore."""
|
||||
|
||||
def __init__(
|
||||
self, vector_store: VectorStoreBase, rdb_store: Optional[RDBMSConnector] = None
|
||||
):
|
||||
"""Initialize Community metastore."""
|
||||
self._vector_store = vector_store
|
||||
self._rdb_store = rdb_store
|
||||
|
||||
config = self._vector_store.get_config()
|
||||
self._vector_space = config.name
|
||||
self._max_chunks_once_load = config.max_chunks_once_load
|
||||
self._max_threads = config.max_threads
|
||||
self._topk = config.topk
|
||||
self._score_threshold = config.score_threshold
|
||||
|
||||
def get(self, community_id: str) -> Community:
|
||||
"""Get community."""
|
||||
raise NotImplementedError("Get community not allowed")
|
||||
|
||||
def list(self) -> List[Community]:
|
||||
"""Get all communities."""
|
||||
raise NotImplementedError("List communities not allowed")
|
||||
|
||||
async def search(self, query: str) -> List[Community]:
|
||||
"""Search communities relevant to query."""
|
||||
chunks = await self._vector_store.asimilar_search_with_scores(
|
||||
query, self._topk, self._score_threshold
|
||||
)
|
||||
return [Community(id=chunk.chunk_id, summary=chunk.content) for chunk in chunks]
|
||||
|
||||
async def save(self, communities: List[Community]):
|
||||
"""Save communities."""
|
||||
chunks = [
|
||||
Chunk(id=c.id, content=c.summary, metadata={"total": len(communities)})
|
||||
for c in communities
|
||||
]
|
||||
await self._vector_store.aload_document_with_limit(
|
||||
chunks, self._max_chunks_once_load, self._max_threads
|
||||
)
|
||||
logger.info(f"Save {len(communities)} communities")
|
||||
|
||||
async def truncate(self):
|
||||
"""Truncate community metastore."""
|
||||
self._vector_store.truncate()
|
||||
|
||||
def drop(self):
|
||||
"""Drop community metastore."""
|
||||
if self._vector_store.vector_name_exists():
|
||||
self._vector_store.delete_vector_name(self._vector_space)
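Community summaries are persisted as chunks in a vector store and recalled by similarity search. A hedged usage sketch (not part of the commit; vector_store is any configured VectorStoreBase):

from dbgpt.storage.knowledge_graph.community.base import Community
from dbgpt.storage.knowledge_graph.community.community_metastore import (
    BuiltinCommunityMetastore,
)


async def index_and_recall(vector_store):
    """vector_store: any VectorStoreBase whose config sets topk/score_threshold."""
    metastore = BuiltinCommunityMetastore(vector_store)
    await metastore.save([Community(id="0", summary="A cluster about graph databases")])
    return await metastore.search("graph database")  # -> List[Community]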
|
dbgpt/storage/knowledge_graph/community/community_store.py (new file, 83 lines)
@@ -0,0 +1,83 @@
|
||||
"""Define the CommunityStore class."""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
|
||||
from dbgpt.storage.knowledge_graph.community.base import (
|
||||
Community,
|
||||
CommunityStoreAdapter,
|
||||
)
|
||||
from dbgpt.storage.knowledge_graph.community.community_metastore import (
|
||||
BuiltinCommunityMetastore,
|
||||
)
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommunityStore:
|
||||
"""CommunityStore Class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
community_store_adapter: CommunityStoreAdapter,
|
||||
community_summarizer: CommunitySummarizer,
|
||||
vector_store: VectorStoreBase,
|
||||
):
|
||||
"""Initialize the CommunityStore class."""
|
||||
self._community_store_adapter = community_store_adapter
|
||||
self._community_summarizer = community_summarizer
|
||||
self._meta_store = BuiltinCommunityMetastore(vector_store)
|
||||
|
||||
async def build_communities(self):
|
||||
"""Discover communities."""
|
||||
community_ids = await (self._community_store_adapter.discover_communities())
|
||||
|
||||
# summarize communities
|
||||
communities = []
|
||||
for community_id in community_ids:
|
||||
community = await (
|
||||
self._community_store_adapter.get_community(community_id)
|
||||
)
|
||||
graph = community.data.format()
|
||||
if not graph:
|
||||
break
|
||||
|
||||
community.summary = await (
|
||||
self._community_summarizer.summarize(graph=graph)
|
||||
)
|
||||
communities.append(community)
|
||||
logger.info(
|
||||
f"Summarize community {community_id}: " f"{community.summary[:50]}..."
|
||||
)
|
||||
|
||||
# truncate then save new summaries
|
||||
await self._meta_store.truncate()
|
||||
await self._meta_store.save(communities)
|
||||
|
||||
async def search_communities(self, query: str) -> List[Community]:
|
||||
"""Search communities."""
|
||||
return await self._meta_store.search(query)
|
||||
|
||||
def truncate(self):
|
||||
"""Truncate community store."""
|
||||
logger.info("Truncate community metastore")
|
||||
self._meta_store.truncate()
|
||||
|
||||
logger.info("Truncate community summarizer")
|
||||
self._community_summarizer.truncate()
|
||||
|
||||
logger.info("Truncate graph")
|
||||
self._community_store_adapter.graph_store.truncate()
|
||||
|
||||
def drop(self):
|
||||
"""Drop community store."""
|
||||
logger.info("Remove community metastore")
|
||||
self._meta_store.drop()
|
||||
|
||||
logger.info("Remove community summarizer")
|
||||
self._community_summarizer.drop()
|
||||
|
||||
logger.info("Remove graph")
|
||||
self._community_store_adapter.graph_store.drop()
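CommunityStore ties the three layers together: discovery through the adapter, summarization through the LLM summarizer, persistence through the metastore. A sketch of the build-then-search cycle (not part of the commit; adapter, summarizer and vector_store are assumed to be configured elsewhere):

from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore


async def refresh_and_query(adapter, summarizer, vector_store, question: str):
    """adapter, summarizer and vector_store are assumed to be configured elsewhere."""
    store = CommunityStore(adapter, summarizer, vector_store)

    # Discover communities, summarize each subgraph with the LLM, persist the summaries.
    await store.build_communities()

    # Retrieve the community summaries most relevant to the question.
    return await store.search_communities(question)

# Run from synchronous code with asyncio.run(refresh_and_query(...)).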
|
dbgpt/storage/knowledge_graph/community/factory.py (new file, 30 lines)
@@ -0,0 +1,30 @@
|
||||
"""CommunityStoreAdapter factory."""
|
||||
import logging
|
||||
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase
|
||||
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
|
||||
from dbgpt.storage.knowledge_graph.community.base import CommunityStoreAdapter
|
||||
from dbgpt.storage.knowledge_graph.community.tugraph_adapter import (
|
||||
TuGraphCommunityStoreAdapter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommunityStoreAdapterFactory:
|
||||
"""Factory for community store adapter."""
|
||||
|
||||
@staticmethod
|
||||
def create(graph_store: GraphStoreBase) -> CommunityStoreAdapter:
|
||||
"""Create a CommunityStoreAdapter instance.
|
||||
|
||||
Args:
|
||||
- graph_store: the graph store instance (currently only TuGraph is supported)
|
||||
"""
|
||||
if isinstance(graph_store, TuGraphStore):
|
||||
return TuGraphCommunityStoreAdapter(graph_store)
|
||||
else:
|
||||
raise Exception(
|
||||
"create community store adapter for %s failed",
|
||||
graph_store.__class__.__name__,
|
||||
)
|
dbgpt/storage/knowledge_graph/community/tugraph_adapter.py (new file, 52 lines)
@@ -0,0 +1,52 @@
|
||||
"""TuGraph Community Store Adapter."""
|
||||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from dbgpt.storage.graph_store.graph import MemoryGraph
|
||||
from dbgpt.storage.knowledge_graph.community.base import (
|
||||
Community,
|
||||
CommunityStoreAdapter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TuGraphCommunityStoreAdapter(CommunityStoreAdapter):
|
||||
"""TuGraph Community Store Adapter."""
|
||||
|
||||
MAX_HIERARCHY_LEVEL = 3
|
||||
|
||||
async def discover_communities(self, **kwargs) -> List[str]:
|
||||
"""Run community discovery with leiden."""
|
||||
mg = self._graph_store.query(
|
||||
"CALL db.plugin.callPlugin"
|
||||
"('CPP','leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
|
||||
)
|
||||
result = mg.get_vertex("json_node").get_prop("description")
|
||||
community_ids = json.loads(result)["community_id_list"]
|
||||
logger.info(f"Discovered {len(community_ids)} communities.")
|
||||
return community_ids
|
||||
|
||||
async def get_community(self, community_id: str) -> Community:
|
||||
"""Get community."""
|
||||
query = (
|
||||
f"MATCH (n:{self._graph_store.get_vertex_type()})"
|
||||
f"WHERE n._community_id = '{community_id}' RETURN n"
|
||||
)
|
||||
edge_query = (
|
||||
f"MATCH (n:{self._graph_store.get_vertex_type()})-"
|
||||
f"[r:{self._graph_store.get_edge_type()}]-"
|
||||
f"(m:{self._graph_store.get_vertex_type()})"
|
||||
f"WHERE n._community_id = '{community_id}' RETURN n,r,m"
|
||||
)
|
||||
|
||||
all_vertex_graph = self._graph_store.aquery(query)
|
||||
all_edge_graph = self._graph_store.aquery(edge_query)
|
||||
all_graph = MemoryGraph()
|
||||
for vertex in all_vertex_graph.vertices():
|
||||
all_graph.upsert_vertex(vertex)
|
||||
for edge in all_edge_graph.edges():
|
||||
all_graph.append_edge(edge)
|
||||
|
||||
return Community(id=community_id, data=all_graph)
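The TuGraph adapter drives the preloaded leiden plugin to stamp vertices with _community_id, then pulls each community's vertices and edges back with two queries. A small sketch (not part of the commit; requires a TuGraphStore created with summary_enabled=True and a reachable TuGraph):

from dbgpt.storage.knowledge_graph.community.tugraph_adapter import (
    TuGraphCommunityStoreAdapter,
)


async def tag_and_fetch(store):
    """store: a TuGraphStore created with summary_enabled=True (live connection)."""
    adapter = TuGraphCommunityStoreAdapter(store)
    community_ids = await adapter.discover_communities()   # runs the leiden plugin
    first = await adapter.get_community(community_ids[0])  # one cluster's vertices and edges
    return community_ids, first.data.format()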
|
dbgpt/storage/knowledge_graph/community_summary.py (new file, 373 lines)
@@ -0,0 +1,373 @@
|
||||
"""Define the CommunitySummaryKnowledgeGraph."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
|
||||
from dbgpt.rag.transformer.graph_extractor import GraphExtractor
|
||||
from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore
from dbgpt.storage.knowledge_graph.community.factory import CommunityStoreAdapterFactory
from dbgpt.storage.knowledge_graph.knowledge_graph import (
    BuiltinKnowledgeGraph,
    BuiltinKnowledgeGraphConfig,
)
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.factory import VectorStoreFactory
from dbgpt.storage.vector_store.filters import MetadataFilters

logger = logging.getLogger(__name__)


class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
    """Community summary knowledge graph config."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    vector_store_type: str = Field(
        default="Chroma", description="The type of vector store."
    )
    user: Optional[str] = Field(
        default=None,
        description="The user of vector store, if not set, will use the default user.",
    )
    password: Optional[str] = Field(
        default=None,
        description=(
            "The password of vector store, if not set, will use the default password."
        ),
    )
    extract_topk: int = Field(
        default=5,
        description="Topk of knowledge graph extract",
    )
    extract_score_threshold: float = Field(
        default=0.3,
        description="Recall score of knowledge graph extract",
    )
    community_topk: int = Field(
        default=50,
        description="Topk of community search in knowledge graph",
    )
    community_score_threshold: float = Field(
        default=0.0,
        description="Recall score of community search in knowledge graph",
    )


class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
    """Community summary knowledge graph class."""

    def __init__(self, config: CommunitySummaryKnowledgeGraphConfig):
        """Initialize community summary knowledge graph class."""
        super().__init__(config)
        self._config = config

        self._vector_store_type = os.getenv(
            "VECTOR_STORE_TYPE", config.vector_store_type
        )
        self._extract_topk = int(
            os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE", config.extract_topk)
        )
        self._extract_score_threshold = float(
            os.getenv(
                "KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE",
                config.extract_score_threshold,
            )
        )
        self._community_topk = int(
            os.getenv(
                "KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE", config.community_topk
            )
        )
        self._community_score_threshold = float(
            os.getenv(
                "KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE",
                config.community_score_threshold,
            )
        )

        def extractor_configure(name: str, cfg: VectorStoreConfig):
            cfg.name = name
            cfg.embedding_fn = config.embedding_fn
            cfg.max_chunks_once_load = config.max_chunks_once_load
            cfg.max_threads = config.max_threads
            cfg.user = config.user
            cfg.password = config.password
            cfg.topk = self._extract_topk
            cfg.score_threshold = self._extract_score_threshold

        self._graph_extractor = GraphExtractor(
            self._llm_client,
            self._model_name,
            VectorStoreFactory.create(
                self._vector_store_type,
                config.name + "_CHUNK_HISTORY",
                extractor_configure,
            ),
        )

        def community_store_configure(name: str, cfg: VectorStoreConfig):
            cfg.name = name
            cfg.embedding_fn = config.embedding_fn
            cfg.max_chunks_once_load = config.max_chunks_once_load
            cfg.max_threads = config.max_threads
            cfg.user = config.user
            cfg.password = config.password
            cfg.topk = self._community_topk
            cfg.score_threshold = self._community_score_threshold

        self._community_store = CommunityStore(
            CommunityStoreAdapterFactory.create(self._graph_store),
            CommunitySummarizer(self._llm_client, self._model_name),
            VectorStoreFactory.create(
                self._vector_store_type,
                config.name + "_COMMUNITY_SUMMARY",
                community_store_configure,
            ),
        )

    def get_config(self) -> BuiltinKnowledgeGraphConfig:
        """Get the knowledge graph config."""
        return self._config

    async def aload_document(self, chunks: List[Chunk]) -> List[str]:
        """Extract and persist graph."""
        # todo add doc node
        for chunk in chunks:
            # todo add chunk node
            # todo add relation doc-chunk

            # extract graphs and save
            graphs = await self._graph_extractor.extract(chunk.content)
            for graph in graphs:
                self._graph_store.insert_graph(graph)

        # build communities and save
        await self._community_store.build_communities()

        return [chunk.chunk_id for chunk in chunks]

    async def asimilar_search_with_scores(
        self,
        text,
        topk,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve relevant community summaries."""
        # global search: retrieve relevant community summaries
        communities = await self._community_store.search_communities(text)
        summaries = [
            f"Section {i + 1}:\n{community.summary}"
            for i, community in enumerate(communities)
        ]
        context = "\n".join(summaries) if summaries else ""

        # local search: extract keywords and explore subgraph
        keywords = await self._keyword_extractor.extract(text)
        subgraph = self._graph_store.explore(keywords, limit=topk).format()
        logger.info(f"Search subgraph from {len(keywords)} keywords")

        if not summaries and not subgraph:
            return []

        # merge search results into context
        content = HYBRID_SEARCH_PT_CN.format(context=context, graph=subgraph)
        return [Chunk(content=content)]

    def truncate(self) -> List[str]:
        """Truncate knowledge graph."""
        logger.info("Truncate community store")
        self._community_store.truncate()
        logger.info("Truncate keyword extractor")
        self._keyword_extractor.truncate()
        logger.info("Truncate triplet extractor")
        self._graph_extractor.truncate()
        return [self._config.name]

    def delete_vector_name(self, index_name: str):
        """Delete knowledge graph."""
        logger.info("Drop community store")
        self._community_store.drop()

        logger.info("Drop keyword extractor")
        self._keyword_extractor.drop()

        logger.info("Drop triplet extractor")
        self._graph_extractor.drop()

HYBRID_SEARCH_PT_CN = (
    "## 角色\n"
    "你非常擅长结合提示词模板提供的[上下文]信息与[知识图谱]信息,"
    "准确恰当地回答用户的问题,并保证不会输出与上下文和知识图谱无关的信息。"
    "\n"
    "## 技能\n"
    "### 技能 1: 上下文理解\n"
    "- 准确地理解[上下文]提供的信息,上下文信息可能被拆分为多个章节。\n"
    "- 上下文的每个章节内容都会以[Section]开始,并按需进行了编号。\n"
    "- 上下文信息提供了与用户问题相关度最高的总结性描述,请合理使用它们。"
    "### 技能 2: 知识图谱理解\n"
    "- 准确地识别[知识图谱]中提供的[Entities:]章节中的实体信息"
    "和[Relationships:]章节中的关系信息,实体和关系信息的一般格式为:\n"
    "```"
    "* 实体信息格式:\n"
    "- (实体名)\n"
    "- (实体名:实体描述)\n"
    "- (实体名:实体属性表)\n"
    "- (文本块ID:文档块内容)\n"
    "- (目录ID:目录名)\n"
    "- (文档ID:文档名称)\n"
    "\n"
    "* 关系信息的格式:\n"
    "- (来源实体名)-[关系名]->(目标实体名)\n"
    "- (来源实体名)-[关系名:关系描述]->(目标实体名)\n"
    "- (来源实体名)-[关系名:关系属性表]->(目标实体名)\n"
    "- (文本块实体)-[包含]->(实体名)\n"
    "- (目录ID)-[包含]->(文本块实体)\n"
    "- (目录ID)-[包含]->(子目录ID)\n"
    "- (文档ID)-[包含]->(文本块实体)\n"
    "- (文档ID)-[包含]->(目录ID)\n"
    "```"
    "- 正确地将关系信息中的实体名/ID与实体信息关联,还原出图结构。"
    "- 将图结构所表达的信息作为用户提问的明细上下文,辅助生成更好的答案。\n"
    "\n"
    "## 约束条件\n"
    "- 不要在答案中描述你的思考过程,直接给出用户问题的答案,不要生成无关信息。\n"
    "- 若[知识图谱]没有提供信息,此时应根据[上下文]提供的信息回答问题。"
    "- 确保以第三人称书写,从客观角度结合[上下文]和[知识图谱]表达的信息回答问题。\n"
    "- 若提供的信息相互矛盾,请解决矛盾并提供一个单一、连贯的描述。\n"
    "- 避免使用停用词和过于常见的词汇。\n"
    "\n"
    "## 参考案例\n"
    "```\n"
    "[上下文]:\n"
    "Section 1:\n"
    "菲尔・贾伯的大儿子叫雅各布・贾伯。\n"
    "Section 2:\n"
    "菲尔・贾伯的小儿子叫比尔・贾伯。\n"
    "[知识图谱]:\n"
    "Entities:\n"
    "(菲尔・贾伯#菲尔兹咖啡创始人)\n"
    "(菲尔兹咖啡#加利福尼亚州伯克利创立的咖啡品牌)\n"
    "(雅各布・贾伯#菲尔・贾伯的儿子)\n"
    "(美国多地#菲尔兹咖啡的扩展地区)\n"
    "\n"
    "Relationships:\n"
    "(菲尔・贾伯#创建#菲尔兹咖啡#1978年在加利福尼亚州伯克利创立)\n"
    "(菲尔兹咖啡#位于#加利福尼亚州伯克利#菲尔兹咖啡的创立地点)\n"
    "(菲尔・贾伯#拥有#雅各布・贾伯#菲尔・贾伯的儿子)\n"
    "(雅各布・贾伯#担任#首席执行官#在2005年成为菲尔兹咖啡的首席执行官)\n"
    "(菲尔兹咖啡#扩展至#美国多地#菲尔兹咖啡的扩展范围)\n"
    "```\n"
    "\n"
    "----\n"
    "\n"
"接下来的[上下文]和[知识图谱]的信息,可以帮助你回答更好地用户的问题。\n"
|
||||
"\n"
|
||||
"[上下文]:\n"
|
||||
"{context}\n"
|
||||
"\n"
|
||||
"[知识图谱]:\n"
|
||||
"{graph}\n"
|
||||
"\n"
|
||||
)
|
||||
|
||||
HYBRID_SEARCH_PT_EN = (
    "## Role\n"
    "You excel at combining the information provided in the [Context] with "
    "information from the [KnowledgeGraph] to accurately and appropriately "
    "answer user questions, ensuring that you do not output information "
    "unrelated to the context and knowledge graph.\n"
    "\n"
    "## Skills\n"
    "### Skill 1: Context Understanding\n"
    "- Accurately understand the information provided in the [Context], "
    "which may be divided into several sections.\n"
    "- Each section in the context will start with [Section] "
    "and may be numbered as needed.\n"
    "- The context provides a summary description most relevant to the user’s "
    "question, and it should be used wisely."
    "### Skill 2: Knowledge Graph Understanding\n"
    "- Accurately identify entity information in the [Entities:] section and "
    "relationship information in the [Relationships:] section "
    "of the [KnowledgeGraph]. The general format for entity "
    "and relationship information is:\n"
    "```"
    "* Entity Information Format:\n"
    "- (entity_name)\n"
    "- (entity_name: entity_description)\n"
    "- (entity_name: entity_property_map)\n"
    "- (chunk_id: chunk_content)\n"
    "- (catalog_id: catalog_name)\n"
    "- (document_id: document_name)\n"
    "\n"
    "* Relationship Information Format:\n"
    "- (source_entity_name)-[relationship_name]->(target_entity_name)\n"
    "- (source_entity_name)-[relationship_name: relationship_description]->"
    "(target_entity_name)\n"
    "- (source_entity_name)-[relationship_name: relationship_property_map]->"
    "(target_entity_name)\n"
    "- (chunk_id)-[Contains]->(entity_name)\n"
    "- (catalog_id)-[Contains]->(chunk_id)\n"
    "- (catalog_id)-[Contains]->(sub_catalog_id)\n"
    "- (document_id)-[Contains]->(chunk_id)\n"
    "- (document_id)-[Contains]->(catalog_id)\n"
    "```"
    "- Correctly associate entity names/IDs in the relationship information "
    "with entity information to restore the graph structure."
    "- Use the information expressed by the graph structure as detailed "
    "context for the user's query to assist in generating better answers.\n"
    "\n"
    "## Constraints\n"
    "- Don't describe your thought process in the answer, provide the answer "
    "to the user's question directly without generating irrelevant information."
    "- If the [KnowledgeGraph] does not provide information, you should answer "
    "the question based on the information provided in the [Context]."
    "- Ensure to write in the third person, responding to questions from "
    "an objective perspective based on the information combined from the "
    "[Context] and the [KnowledgeGraph].\n"
    "- If the provided information is contradictory, resolve the "
    "contradictions and provide a single, coherent description.\n"
    "- Avoid using stop words and overly common vocabulary.\n"
    "\n"
    "## Reference Example\n"
    "```\n"
    "[Context]:\n"
    "Section 1:\n"
"Phil Schiller's eldest son is Jacob Schiller.\n"
|
||||
"Section 2:\n"
|
||||
"Phil Schiller's youngest son is Bill Schiller.\n"
|
||||
"[KnowledgeGraph]:\n"
|
||||
"Entities:\n"
|
||||
"(Phil Jaber#Founder of Philz Coffee)\n"
|
||||
"(Philz Coffee#Coffee brand founded in Berkeley, California)\n"
|
||||
"(Jacob Jaber#Son of Phil Jaber)\n"
|
||||
"(Multiple locations in the USA#Expansion regions of Philz Coffee)\n"
|
||||
"\n"
|
||||
"Relationships:\n"
|
||||
"(Phil Jaber#Created#Philz Coffee"
|
||||
"#Founded in Berkeley, California in 1978)\n"
|
||||
"(Philz Coffee#Located in#Berkeley, California"
|
||||
"#Founding location of Philz Coffee)\n"
|
||||
"(Phil Jaber#Has#Jacob Jaber#Son of Phil Jaber)\n"
|
||||
"(Jacob Jaber#Serves as#CEO#Became CEO of Philz Coffee in 2005)\n"
|
||||
"(Philz Coffee#Expanded to#Multiple locations in the USA"
|
||||
"#Expansion regions of Philz Coffee)\n"
|
||||
"```\n"
|
||||
"\n"
|
||||
"----\n"
|
||||
"\n"
|
||||
"The following information from the [Context] and [KnowledgeGraph] can "
|
||||
"help you better answer user questions.\n"
|
||||
"\n"
|
||||
"[Context]:\n"
|
||||
"{context}\n"
|
||||
"\n"
|
||||
"[KnowledgeGraph]:\n"
|
||||
"{graph}\n"
|
||||
"\n"
|
||||
)
|
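For orientation, a minimal usage sketch of the community-summary path defined above, assuming the caller already has an LLM client, an embedding function, and a reachable graph store such as TuGraph; `llm_client` and `embedding_fn` are placeholders supplied by the caller.

from dbgpt.core import Chunk
from dbgpt.storage.knowledge_graph.community_summary import (
    CommunitySummaryKnowledgeGraph,
    CommunitySummaryKnowledgeGraphConfig,
)


async def build_and_query(llm_client, embedding_fn):
    """llm_client and embedding_fn are placeholders provided by the caller."""
    config = CommunitySummaryKnowledgeGraphConfig(
        name="demo_graph",
        llm_client=llm_client,
        model_name="demo-model",      # placeholder model name
        embedding_fn=embedding_fn,
        vector_store_type="Chroma",   # backs the chunk-history and summary stores
        community_topk=20,
    )
    graph = CommunitySummaryKnowledgeGraph(config)

    # Indexing: a graph is extracted per chunk and persisted, then communities
    # are detected and summarized.
    await graph.aload_document([Chunk(content="Phil Jaber founded Philz Coffee.")])

    # Hybrid retrieval: community summaries (global) plus keyword subgraph (local),
    # merged into a single prompt context chunk.
    return await graph.asimilar_search_with_scores(
        "Who founded Philz Coffee?", topk=5, score_threshold=0.0
    )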
@@ -36,8 +36,9 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):

    def __init__(self, config: BuiltinKnowledgeGraphConfig):
        """Create builtin knowledge graph instance."""
        self._config = config
        super().__init__()
        self._config = config

        self._llm_client = config.llm_client
        if not self._llm_client:
            raise ValueError("No llm client provided.")
@@ -45,17 +46,19 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
        self._model_name = config.model_name
        self._triplet_extractor = TripletExtractor(self._llm_client, self._model_name)
        self._keyword_extractor = KeywordExtractor(self._llm_client, self._model_name)
        self._graph_store_type = (
            os.getenv("GRAPH_STORE_TYPE", "TuGraph") or config.graph_store_type
        )
        self._graph_store = self.__init_graph_store(config)

    def __init_graph_store(self, config) -> GraphStoreBase:
        def configure(cfg: GraphStoreConfig):
            cfg.name = self._config.name
            cfg.embedding_fn = self._config.embedding_fn
            cfg.name = config.name
            cfg.embedding_fn = config.embedding_fn

        self._graph_store: GraphStoreBase = GraphStoreFactory.create(
            self._graph_store_type, configure
        )
        graph_store_type = os.getenv("GRAPH_STORE_TYPE") or config.graph_store_type
        return GraphStoreFactory.create(graph_store_type, configure)

    def get_config(self) -> BuiltinKnowledgeGraphConfig:
        """Get the knowledge graph config."""
        return self._config

    def load_document(self, chunks: List[Chunk]) -> List[str]:
        """Extract and persist triplets to graph store."""
@@ -113,35 +116,59 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):

        # extract keywords and explore graph store
        keywords = await self._keyword_extractor.extract(text)
        subgraph = self._graph_store.explore(keywords, limit=topk)
        subgraph = self._graph_store.explore(keywords, limit=topk).format()
        logger.info(f"Search subgraph from {len(keywords)} keywords")

        if not subgraph:
            return []

        content = (
            "The following vertices and edges data after [Subgraph Data] "
            "are retrieved from the knowledge graph based on the keywords:\n"
            f"Keywords:\n{','.join(keywords)}\n"
            "The following entities and relationships provided after "
            "[Subgraph] are retrieved from the knowledge graph "
            "based on the keywords:\n"
            f"\"{','.join(keywords)}\".\n"
            "---------------------\n"
            "You can refer to the sample vertices and edges to understand "
            "the real knowledge graph data provided by [Subgraph Data].\n"
            "Sample vertices:\n"
            "The following examples after [Entities] and [Relationships] "
            "can help you understand the data format of the knowledge graph, "
            "but do not use them in the answer.\n"
            "[Entities]:\n"
            "(alice)\n"
            "(bob:{age:28})\n"
            '(carry:{age:18;role:"teacher"})\n\n'
            "Sample edges:\n"
            "[Relationships]:\n"
            "(alice)-[reward]->(alice)\n"
            '(alice)-[notify:{method:"email"}]->'
            '(carry:{age:18;role:"teacher"})\n'
            '(bob:{age:28})-[teach:{course:"math";hour:180}]->(alice)\n'
            "---------------------\n"
            f"Subgraph Data:\n{subgraph.format()}\n"
            f"[Subgraph]:\n{subgraph}\n"
        )
        return [Chunk(content=content, metadata=subgraph.schema())]
        return [Chunk(content=content)]

    def query_graph(self, limit: Optional[int] = None) -> Graph:
        """Query graph."""
        return self._graph_store.get_full_graph(limit)

    def truncate(self) -> List[str]:
        """Truncate knowledge graph."""
        logger.info(f"Truncate graph {self._config.name}")
        self._graph_store.truncate()

        logger.info("Truncate keyword extractor")
        self._keyword_extractor.truncate()

        logger.info("Truncate triplet extractor")
        self._triplet_extractor.truncate()

        return [self._config.name]

    def delete_vector_name(self, index_name: str):
        """Delete vector name."""
        logger.info(f"Remove graph index {index_name}")
        logger.info(f"Drop graph {index_name}")
        self._graph_store.drop()

        logger.info("Drop keyword extractor")
        self._keyword_extractor.drop()

        logger.info("Drop triplet extractor")
        self._triplet_extractor.drop()
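The refactored `__init_graph_store` above resolves the backend from the `GRAPH_STORE_TYPE` environment variable, falling back to `config.graph_store_type`, and passes a one-argument configure callback to the factory. A minimal sketch of the same pattern in isolation; the factory and config import paths are assumptions.

import os

from dbgpt.storage.graph_store.base import GraphStoreConfig       # assumed import path
from dbgpt.storage.graph_store.factory import GraphStoreFactory   # assumed import path


def configure(cfg: GraphStoreConfig):
    # The factory constructs the backend-specific config and hands it to us,
    # so only the fields the caller cares about need to be filled in.
    cfg.name = "demo_graph"


graph_store_type = os.getenv("GRAPH_STORE_TYPE") or "TuGraph"
graph_store = GraphStoreFactory.create(graph_store_type, configure)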
@@ -1,12 +1,8 @@
"""OpenSPG class."""
import logging
from typing import List, Optional

from dbgpt._private.pydantic import ConfigDict
from dbgpt.core import Chunk
from dbgpt.storage.graph_store.graph import Graph, MemoryGraph
from dbgpt.storage.knowledge_graph.base import KnowledgeGraphBase, KnowledgeGraphConfig
from dbgpt.storage.vector_store.filters import MetadataFilters

logger = logging.getLogger(__name__)

@@ -21,29 +17,3 @@ class OpenSPG(KnowledgeGraphBase):
    """OpenSPG class."""

    # todo: add OpenSPG implementation

    def __init__(self, config: OpenSPGConfig):
        """Initialize the OpenSPG with config details."""
        pass

    def load_document(self, chunks: List[Chunk]) -> List[str]:
        """Load document."""
        return []

    def similar_search_with_scores(
        self,
        text,
        topk,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Similar with scores."""
        return []

    def query_graph(self, limit: Optional[int] = None) -> Graph:
        """Query graph."""
        return MemoryGraph()

    def delete_vector_name(self, index_name: str):
        """Delete vector name."""
        pass
@@ -56,6 +56,15 @@ def _import_builtin_knowledge_graph() -> Tuple[Type, Type]:
    return BuiltinKnowledgeGraph, BuiltinKnowledgeGraphConfig


def _import_community_summary_knowledge_graph() -> Tuple[Type, Type]:
    from dbgpt.storage.knowledge_graph.community_summary import (
        CommunitySummaryKnowledgeGraph,
        CommunitySummaryKnowledgeGraphConfig,
    )

    return CommunitySummaryKnowledgeGraph, CommunitySummaryKnowledgeGraphConfig


def _import_openspg() -> Tuple[Type, Type]:
    from dbgpt.storage.knowledge_graph.open_spg import OpenSPG, OpenSPGConfig

@@ -86,6 +95,8 @@ def __getattr__(name: str) -> Tuple[Type, Type]:
        return _import_elastic()
    elif name == "KnowledgeGraph":
        return _import_builtin_knowledge_graph()
    elif name == "CommunitySummaryKnowledgeGraph":
        return _import_community_summary_knowledge_graph()
    elif name == "OpenSPG":
        return _import_openspg()
    elif name == "FullText":
@@ -103,7 +114,7 @@ __vector_store__ = [
    "ElasticSearch",
]

__knowledge_graph__ = ["KnowledgeGraph", "OpenSPG"]
__knowledge_graph__ = ["KnowledgeGraph", "CommunitySummaryKnowledgeGraph", "OpenSPG"]

__document_store__ = ["FullText"]
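The lazy `__getattr__` registry above returns a `(store_cls, config_cls)` tuple per backend name, so callers can resolve an implementation without importing every backend up front. A rough sketch of resolving a knowledge-graph backend by name, assuming the registry is the `dbgpt.storage.vector_store` package (the same module the factory further below iterates):

from dbgpt.storage import vector_store as registry


def find_knowledge_graph_type(kg_type: str):
    """Resolve the (store_cls, config_cls) pair for a knowledge graph type name."""
    for name in registry.__knowledge_graph__:
        if name.lower() == kg_type.lower():
            # Attribute access goes through __getattr__ and lazily imports the backend.
            return getattr(registry, name)
    raise ValueError(f"Knowledge graph type {kg_type} not supported")


store_cls, cfg_cls = find_knowledge_graph_type("CommunitySummaryKnowledgeGraph")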
@@ -99,6 +99,14 @@ class VectorStoreConfig(IndexStoreConfig):
            "The password of vector store, if not set, will use the default password."
        ),
    )
    topk: int = Field(
        default=5,
        description="Topk of vector search",
    )
    score_threshold: float = Field(
        default=0.3,
        description="Recall score of vector search",
    )


class VectorStoreBase(IndexStoreBase, ABC):
@@ -108,6 +116,10 @@ class VectorStoreBase(IndexStoreBase, ABC):
        """Initialize vector store."""
        super().__init__(executor)

    @abstractmethod
    def get_config(self) -> VectorStoreConfig:
        """Get the vector store config."""

    def filter_by_score_threshold(
        self, chunks: List[Chunk], score_threshold: float
    ) -> List[Chunk]:
@@ -126,7 +138,7 @@ class VectorStoreBase(IndexStoreBase, ABC):
                    metadata=chunk.metadata,
                    content=chunk.content,
                    score=chunk.score,
                    chunk_id=str(id),
                    chunk_id=chunk.chunk_id,
                )
                for chunk in chunks
                if chunk.score >= score_threshold
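The new `topk` and `score_threshold` fields give every store a default retrieval budget, and `filter_by_score_threshold` now carries the original `chunk_id` through instead of regenerating it. A tiny sketch of the filtering semantics, using hypothetical scores:

from dbgpt.core import Chunk

chunks = [
    Chunk(chunk_id="a", content="relevant passage", score=0.82),
    Chunk(chunk_id="b", content="weak match", score=0.12),
]
# With score_threshold=0.3, only the first chunk survives, and its original
# chunk_id is preserved rather than being replaced by a generated id.
kept = [c for c in chunks if c.score >= 0.3]
assert [c.chunk_id for c in kept] == ["a"]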
@@ -63,6 +63,8 @@ class ChromaStore(VectorStoreBase):
            vector_store_config(ChromaVectorConfig): vector store config.
        """
        super().__init__()
        self._vector_store_config = vector_store_config

        chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
        chroma_path = chroma_vector_config.get(
            "persist_path", os.path.join(PILOT_PATH, "data")
@@ -89,6 +91,10 @@ class ChromaStore(VectorStoreBase):
            metadata=collection_metadata,
        )

    def get_config(self) -> ChromaVectorConfig:
        """Get the vector store config."""
        return self._vector_store_config

    def similar_search(
        self, text, topk, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
@@ -100,10 +106,16 @@ class ChromaStore(VectorStoreBase):
            filters=filters,
        )
        return [
            Chunk(content=chroma_result[0], metadata=chroma_result[1] or {}, score=0.0)
            Chunk(
                content=chroma_result[0],
                metadata=chroma_result[1] or {},
                score=0.0,
                chunk_id=chroma_result[2],
            )
            for chroma_result in zip(
                chroma_results["documents"][0],
                chroma_results["metadatas"][0],
                chroma_results["ids"][0],
            )
        ]

@@ -134,12 +146,14 @@ class ChromaStore(VectorStoreBase):
                    content=chroma_result[0],
                    metadata=chroma_result[1] or {},
                    score=(1 - chroma_result[2]),
                    chunk_id=chroma_result[3],
                )
            )
            for chroma_result in zip(
                chroma_results["documents"][0],
                chroma_results["metadatas"][0],
                chroma_results["distances"][0],
                chroma_results["ids"][0],
            )
        ]
        return self.filter_by_score_threshold(chunks, score_threshold)
@@ -181,6 +195,20 @@ class ChromaStore(VectorStoreBase):
        if len(ids) > 0:
            self._collection.delete(ids=ids)

    def truncate(self) -> List[str]:
        """Truncate data index_name."""
        logger.info(f"begin truncate chroma collection:{self._collection.name}")
        results = self._collection.get()
        ids = results.get("ids")
        if ids:
            self._collection.delete(ids=ids)
            logger.info(
                f"truncate chroma collection {self._collection.name} "
                f"{len(ids)} chunks success"
            )
            return ids
        return []

    def convert_metadata_filters(
        self,
        filters: MetadataFilters,
@@ -126,6 +126,8 @@ class ElasticStore(VectorStoreBase):
            vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config.
        """
        super().__init__()
        self._vector_store_config = vector_store_config

        connect_kwargs = {}
        elasticsearch_vector_config = vector_store_config.dict()
        self.uri = elasticsearch_vector_config.get("uri") or os.getenv(
@@ -234,6 +236,10 @@ class ElasticStore(VectorStoreBase):
        except Exception as e:
            logger.error(f"ElasticSearch connection failed: {e}")

    def get_config(self) -> ElasticsearchVectorConfig:
        """Get the vector store config."""
        return self._vector_store_config

    def load_document(
        self,
        chunks: List[Chunk],
dbgpt/storage/vector_store/factory.py (new file)
@@ -0,0 +1,44 @@
"""Vector store factory."""
|
||||
import logging
|
||||
from typing import Tuple, Type
|
||||
|
||||
from dbgpt.storage import vector_store
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VectorStoreFactory:
|
||||
"""Factory for vector store."""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
vector_store_type: str, vector_space_name: str, vector_store_configure=None
|
||||
) -> VectorStoreBase:
|
||||
"""Create a VectorStore instance.
|
||||
|
||||
Args:
|
||||
- vector_store_type: vector store type Chroma, Milvus, etc.
|
||||
- vector_store_config: vector store config
|
||||
"""
|
||||
store_cls, cfg_cls = VectorStoreFactory.__find_type(vector_store_type)
|
||||
|
||||
try:
|
||||
config = cfg_cls()
|
||||
if vector_store_configure:
|
||||
vector_store_configure(vector_space_name, config)
|
||||
return store_cls(config)
|
||||
except Exception as e:
|
||||
logger.error("create vector store failed: %s", e)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def __find_type(vector_store_type: str) -> Tuple[Type, Type]:
|
||||
for t in vector_store.__vector_store__:
|
||||
if t.lower() == vector_store_type.lower():
|
||||
store_cls, cfg_cls = getattr(vector_store, t)
|
||||
if issubclass(store_cls, VectorStoreBase) and issubclass(
|
||||
cfg_cls, VectorStoreConfig
|
||||
):
|
||||
return store_cls, cfg_cls
|
||||
raise Exception(f"Vector store {vector_store_type} not supported")
|
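A usage sketch for the factory: the optional third argument is a callable that receives the collection name and a freshly built config, which is how community_summary.py wires up its `_CHUNK_HISTORY` and `_COMMUNITY_SUMMARY` stores. The embedding function is a placeholder supplied by the caller.

from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.factory import VectorStoreFactory


def make_store(embedding_fn):
    """embedding_fn is a placeholder Embeddings instance provided by the caller."""

    def configure(name: str, cfg: VectorStoreConfig):
        # Invoked by the factory with the collection name and a fresh config
        # object before the store class is instantiated.
        cfg.name = name
        cfg.embedding_fn = embedding_fn
        cfg.topk = 10
        cfg.score_threshold = 0.3

    return VectorStoreFactory.create("Chroma", "demo_COMMUNITY_SUMMARY", configure)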
@@ -150,6 +150,8 @@ class MilvusStore(VectorStoreBase):
            refer to https://milvus.io/docs/v2.0.x/manage_connection.md
        """
        super().__init__()
        self._vector_store_config = vector_store_config

        try:
            from pymilvus import connections
        except ImportError:
@@ -363,6 +365,10 @@ class MilvusStore(VectorStoreBase):

        return res.primary_keys

    def get_config(self) -> MilvusVectorConfig:
        """Get the vector store config."""
        return self._vector_store_config

    def load_document(self, chunks: List[Chunk]) -> List[str]:
        """Load document in vector database."""
        batch_size = 500
@@ -718,6 +718,8 @@ class OceanBaseStore(VectorStoreBase):
        if vector_store_config.embedding_fn is None:
            raise ValueError("embedding_fn is required for OceanBaseStore")
        super().__init__()
        self._vector_store_config = vector_store_config

        self.embeddings = vector_store_config.embedding_fn
        self.collection_name = vector_store_config.name
        vector_store_config = vector_store_config.dict()
@@ -760,6 +762,10 @@ class OceanBaseStore(VectorStoreBase):
            enable_normalize_vector=self.OB_ENABLE_NORMALIZE_VECTOR,
        )

    def get_config(self) -> OceanBaseConfig:
        """Get the vector store config."""
        return self._vector_store_config

    def similar_search(
        self, text, topk, filters: Optional[MetadataFilters] = None, **kwargs: Any
    ) -> List[Chunk]:
@@ -64,6 +64,8 @@ class PGVectorStore(VectorStoreBase):
                "Please install the `langchain` package to use the PGVector."
            )
        super().__init__()
        self._vector_store_config = vector_store_config

        self.connection_string = vector_store_config.connection_string
        self.embeddings = vector_store_config.embedding_fn
        self.collection_name = vector_store_config.name
@@ -74,6 +76,10 @@ class PGVectorStore(VectorStoreBase):
            connection_string=self.connection_string,
        )

    def get_config(self) -> PGVectorConfig:
        """Get the vector store config."""
        return self._vector_store_config

    def similar_search(
        self, text: str, topk: int, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
@@ -69,6 +69,8 @@ class WeaviateStore(VectorStoreBase):
                "Please install it with `pip install weaviate-client`."
            )
        super().__init__()
        self._vector_store_config = vector_store_config

        self.weaviate_url = vector_store_config.weaviate_url
        self.embedding = vector_store_config.embedding_fn
        self.vector_name = vector_store_config.name
@@ -78,6 +80,10 @@ class WeaviateStore(VectorStoreBase):

        self.vector_store_client = weaviate.Client(self.weaviate_url)

    def get_config(self) -> WeaviateVectorConfig:
        """Get the vector store config."""
        return self._vector_store_config

    def similar_search(
        self, text: str, topk: int, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]: