feat(GraphRAG): enhance GraphRAG by graph community summary (#1801)

Co-authored-by: Florian <fanzhidongyzby@163.com>
Co-authored-by: KingSkyLi <15566300566@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: yvonneyx <zhuyuxin0627@gmail.com>
M1n9X
2024-08-30 21:59:44 +08:00
committed by GitHub
parent 471689ba20
commit 759f7d99cc
59 changed files with 29316 additions and 411 deletions

View File

@@ -5,7 +5,7 @@ from concurrent.futures import Executor, ThreadPoolExecutor
from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.rag.index.base import logger
from dbgpt.rag.index.base import IndexStoreConfig, logger
from dbgpt.storage.full_text.base import FullTextStoreBase
from dbgpt.storage.vector_store.elastic_store import ElasticsearchVectorConfig
from dbgpt.storage.vector_store.filters import MetadataFilters
@@ -35,6 +35,7 @@ class ElasticDocumentStore(FullTextStoreBase):
This similarity has the following options:
"""
super().__init__()
self._es_config = es_config
from elasticsearch import Elasticsearch
self._es_config = es_config
@@ -94,6 +95,10 @@ class ElasticDocumentStore(FullTextStoreBase):
)
self._executor = executor or ThreadPoolExecutor()
def get_config(self) -> IndexStoreConfig:
"""Get the es store config."""
return self._es_config
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document in elasticsearch.

View File

@@ -2,11 +2,11 @@
from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.rag.index.base import IndexStoreBase
from dbgpt.storage.full_text.base import FullTextStoreBase
from dbgpt.storage.vector_store.filters import MetadataFilters
class OpenSearch(IndexStoreBase):
class OpenSearch(FullTextStoreBase):
"""OpenSearch index store."""
def load_document(self, chunks: List[Chunk]) -> List[str]:

View File

@@ -1,7 +1,7 @@
"""Graph store base class."""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
from typing import Generator, List, Optional, Tuple
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core import Embeddings
@@ -23,15 +23,35 @@ class GraphStoreConfig(BaseModel):
default=None,
description="The embedding function of graph store, optional.",
)
summary_enabled: bool = Field(
default=False,
description="Enable graph community summary or not.",
)
class GraphStoreBase(ABC):
"""Graph store base class."""
@abstractmethod
def get_config(self) -> GraphStoreConfig:
"""Get the graph store config."""
@abstractmethod
def get_vertex_type(self) -> str:
"""Get the vertex type."""
@abstractmethod
def get_edge_type(self) -> str:
"""Get the edge type."""
@abstractmethod
def insert_triplet(self, sub: str, rel: str, obj: str):
"""Add triplet."""
@abstractmethod
def insert_graph(self, graph: Graph):
"""Add graph."""
@abstractmethod
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
"""Get triplets."""
@@ -40,6 +60,10 @@ class GraphStoreBase(ABC):
def delete_triplet(self, sub: str, rel: str, obj: str):
"""Delete triplet."""
@abstractmethod
def truncate(self):
"""Truncate Graph."""
@abstractmethod
def drop(self):
"""Drop graph."""
@@ -66,3 +90,11 @@ class GraphStoreBase(ABC):
@abstractmethod
def query(self, query: str, **args) -> Graph:
"""Execute a query."""
def aquery(self, query: str, **args) -> Graph:
"""Async execute a query."""
return self.query(query, **args)
@abstractmethod
def stream_query(self, query: str) -> Generator[Graph, None, None]:
"""Execute stream query."""

View File

@@ -1,4 +1,4 @@
"""Connector for vector store."""
"""Graph store factory."""
import logging
from typing import Tuple, Type

View File

@@ -1,4 +1,4 @@
"""Graph store base class."""
"""Graph definition."""
import itertools
import json
import logging
@@ -24,9 +24,15 @@ class Direction(Enum):
class Elem(ABC):
"""Elem class."""
def __init__(self):
def __init__(self, name: Optional[str] = None):
"""Initialize Elem."""
self._props = {}
self._name = name
self._props: Dict[str, Any] = {}
@property
def name(self) -> str:
"""Return the edge label."""
return self._name or ""
@property
def props(self) -> Dict[str, Any]:
@@ -46,14 +52,17 @@ class Elem(ABC):
self._props.pop(key, None)
def has_props(self, **props):
"""Check if the element has the specified properties with the given values."""
"""Check all key-value pairs exist."""
return all(self._props.get(k) == v for k, v in props.items())
@abstractmethod
def format(self, label_key: Optional[str] = None):
def format(self) -> str:
"""Format properties into a string."""
if len(self._props) == 1:
return str(next(iter(self._props.values())))
formatted_props = [
f"{k}:{json.dumps(v)}" for k, v in self._props.items() if k != label_key
f"{k}:{json.dumps(v, ensure_ascii=False)}" for k, v in self._props.items()
]
return f"{{{';'.join(formatted_props)}}}"
@@ -61,9 +70,9 @@ class Elem(ABC):
class Vertex(Elem):
"""Vertex class."""
def __init__(self, vid: str, **props):
def __init__(self, vid: str, name: Optional[str] = None, **props):
"""Initialize Vertex."""
super().__init__()
super().__init__(name)
self._vid = vid
for k, v in props.items():
self.set_prop(k, v)
@@ -73,26 +82,43 @@ class Vertex(Elem):
"""Return the vertex ID."""
return self._vid
def format(self, label_key: Optional[str] = None):
"""Format vertex properties into a string."""
label = self.get_prop(label_key) if label_key else self._vid
props_str = super().format(label_key)
if props_str == "{}":
return f"({label})"
@property
def name(self) -> str:
"""Return the vertex name."""
return super().name or self._vid
def format(self, concise: bool = False):
"""Format vertex into a string."""
name = self._name or self._vid
if concise:
return f"({name})"
if self._props:
return f"({name}:{super().format()})"
else:
return f"({label}:{props_str})"
return f"({name})"
def __str__(self):
"""Return the vertex ID as its string representation."""
return f"({self._vid})"
class IdVertex(Vertex):
"""IdVertex class."""
def __init__(self, vid: str):
"""Initialize Idvertex."""
super().__init__(vid)
class Edge(Elem):
"""Edge class."""
def __init__(self, sid: str, tid: str, **props):
def __init__(self, sid: str, tid: str, name: str, **props):
"""Initialize Edge."""
super().__init__()
assert name, "Edge name is required"
super().__init__(name)
self._sid = sid
self._tid = tid
for k, v in props.items():
@@ -117,23 +143,20 @@ class Edge(Elem):
else:
raise ValueError(f"Get nid of {vid} on {self} failed")
def format(self, label_key: Optional[str] = None):
def format(self):
"""Format the edge properties into a string."""
label = self.get_prop(label_key) if label_key else ""
props_str = super().format(label_key)
if props_str == "{}":
return f"-[{label}]->" if label else "->"
if self._props:
return f"-[{self._name}:{super().format()}]->"
else:
return f"-[{label}:{props_str}]->" if label else f"-[{props_str}]->"
return f"-[{self._name}]->"
def triplet(self, label_key: str) -> Tuple[str, str, str]:
def triplet(self) -> Tuple[str, str, str]:
"""Return a triplet."""
assert label_key, "label key is needed"
return self._sid, str(self.get_prop(label_key)), self._tid
return self.sid, self.name, self.tid
def __str__(self):
"""Return the edge '(sid)->(tid)'."""
return f"({self._sid})->({self._tid})"
return f"({self._sid})-[{self._name}]->({self._tid})"
class Graph(ABC):
@@ -177,8 +200,8 @@ class Graph(ABC):
"""Delete vertices and their neighbor edges."""
@abstractmethod
def del_edges(self, sid: str, tid: str, **props):
"""Delete edges(sid -> tid) matches props."""
def del_edges(self, sid: str, tid: str, name: str, **props):
"""Delete edges(sid -[name]-> tid) matches props."""
@abstractmethod
def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
@@ -203,19 +226,19 @@ class Graph(ABC):
def format(self) -> str:
"""Format graph data to string."""
@abstractmethod
def truncate(self):
"""Truncate graph."""
class MemoryGraph(Graph):
"""Graph class."""
def __init__(self, vertex_label: Optional[str] = None, edge_label: str = "label"):
def __init__(self):
"""Initialize MemoryGraph with vertex label and edge label."""
assert edge_label, "Edge label is needed"
# metadata
self._vertex_label = vertex_label
self._edge_label = edge_label
self._vertex_prop_keys = {vertex_label} if vertex_label else set()
self._edge_prop_keys = {edge_label}
self._vertex_prop_keys = set()
self._edge_prop_keys = set()
self._edge_count = 0
# init vertices, out edges, in edges index
@@ -223,26 +246,6 @@ class MemoryGraph(Graph):
self._oes: Any = defaultdict(lambda: defaultdict(set))
self._ies: Any = defaultdict(lambda: defaultdict(set))
@property
def vertex_label(self):
"""Return the label for vertices."""
return self._vertex_label
@property
def edge_label(self):
"""Return the label for edges."""
return self._edge_label
@property
def vertex_prop_keys(self):
"""Return a set of property keys for vertices."""
return self._vertex_prop_keys
@property
def edge_prop_keys(self):
"""Return a set of property keys for edges."""
return self._edge_prop_keys
@property
def vertex_count(self):
"""Return the number of vertices in the graph."""
@@ -256,7 +259,10 @@ class MemoryGraph(Graph):
def upsert_vertex(self, vertex: Vertex):
"""Insert or update a vertex based on its ID."""
if vertex.vid in self._vs:
self._vs[vertex.vid].props.update(vertex.props)
if isinstance(self._vs[vertex.vid], IdVertex):
self._vs[vertex.vid] = vertex
else:
self._vs[vertex.vid].props.update(vertex.props)
else:
self._vs[vertex.vid] = vertex
@@ -265,9 +271,6 @@ class MemoryGraph(Graph):
def append_edge(self, edge: Edge):
"""Append an edge if it doesn't exist; requires edge label."""
if self.edge_label not in edge.props.keys():
raise ValueError(f"Edge prop '{self.edge_label}' is needed")
sid = edge.sid
tid = edge.tid
@@ -275,8 +278,8 @@ class MemoryGraph(Graph):
return False
# init vertex index
self._vs.setdefault(sid, Vertex(sid))
self._vs.setdefault(tid, Vertex(tid))
self._vs.setdefault(sid, IdVertex(sid))
self._vs.setdefault(tid, IdVertex(tid))
# update edge index
self._oes[sid][tid].add(edge)
@@ -346,18 +349,19 @@ class MemoryGraph(Graph):
self.del_neighbor_edges(vid, Direction.BOTH)
self._vs.pop(vid, None)
def del_edges(self, sid: str, tid: str, **props):
def del_edges(self, sid: str, tid: str, name: str, **props):
"""Delete edges."""
old_edge_cnt = len(self._oes[sid][tid])
if not props:
self._edge_count -= old_edge_cnt
self._oes[sid].pop(tid, None)
self._ies[tid].pop(sid, None)
return
def remove_matches(es):
return set(filter(lambda e: not e.has_props(**props), es))
return set(
filter(
lambda e: not (
(name == e.name if name else True) and e.has_props(**props)
),
es,
)
)
self._oes[sid][tid] = remove_matches(self._oes[sid][tid])
self._ies[tid][sid] = remove_matches(self._ies[tid][sid])
@@ -439,12 +443,10 @@ class MemoryGraph(Graph):
"schema": [
{
"type": "VERTEX",
"label": f"{self._vertex_label}",
"properties": [{"name": k} for k in self._vertex_prop_keys],
},
{
"type": "EDGE",
"label": f"{self._edge_label}",
"properties": [{"name": k} for k in self._edge_prop_keys],
},
]
@@ -452,14 +454,30 @@ class MemoryGraph(Graph):
def format(self) -> str:
"""Format graph to string."""
vs_str = "\n".join(v.format(self.vertex_label) for v in self.vertices())
vs_str = "\n".join(v.format() for v in self.vertices())
es_str = "\n".join(
f"{self.get_vertex(e.sid).format(self.vertex_label)}"
f"{e.format(self.edge_label)}"
f"{self.get_vertex(e.tid).format(self.vertex_label)}"
f"{self.get_vertex(e.sid).format(concise=True)}"
f"{e.format()}"
f"{self.get_vertex(e.tid).format(concise=True)}"
for e in self.edges()
)
return f"Vertices:\n{vs_str}\n\nEdges:\n{es_str}"
return (
f"Entities:\n{vs_str}\n\n" f"Relationships:\n{es_str}"
if (vs_str or es_str)
else ""
)
def truncate(self):
"""Truncate graph."""
# clean metadata
self._vertex_prop_keys.clear()
self._edge_prop_keys.clear()
self._edge_count = 0
# clean data and index
self._vs.clear()
self._oes.clear()
self._ies.clear()
def graphviz(self, name="g"):
"""View graphviz graph: https://dreampuf.github.io/GraphvizOnline."""
@@ -468,7 +486,7 @@ class MemoryGraph(Graph):
g.add_node(vertex.vid)
for edge in self.edges():
triplet = edge.triplet(self.edge_label)
triplet = edge.triplet()
g.add_edge(triplet[0], triplet[2], label=triplet[1])
digraph = nx.nx_agraph.to_agraph(g).to_string()
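To make the label-key removal concrete, here is a hedged sketch of the new name-based API (import path taken from this diff):

from dbgpt.storage.graph_store.graph import Edge, MemoryGraph, Vertex

g = MemoryGraph()
g.upsert_vertex(Vertex("alice", name="alice", age=30))
g.upsert_vertex(Vertex("bob", name="bob"))
# The edge name is now a required positional argument; props remain optional.
g.append_edge(Edge("alice", "bob", "knows", since=2020))

print(g.format())                       # "Entities:\n...\n\nRelationships:\n..."
print(next(iter(g.edges())).triplet())  # ('alice', 'knows', 'bob')
g.del_edges("alice", "bob", "knows")
g.truncate()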

View File

@@ -1,9 +1,9 @@
"""Graph store base class."""
"""Memory graph store."""
import json
import logging
from typing import List, Optional, Tuple
from typing import Generator, List, Optional, Tuple
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt._private.pydantic import ConfigDict
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph
@@ -15,32 +15,51 @@ class MemoryGraphStoreConfig(GraphStoreConfig):
model_config = ConfigDict(arbitrary_types_allowed=True)
edge_name_key: str = Field(
default="label",
description="The label of edge name, `label` by default.",
)
class MemoryGraphStore(GraphStoreBase):
"""Memory graph store."""
def __init__(self, graph_store_config: MemoryGraphStoreConfig):
"""Initialize MemoryGraphStore with a memory graph."""
self._edge_name_key = graph_store_config.edge_name_key
self._graph = MemoryGraph(edge_label=self._edge_name_key)
self._graph_store_config = graph_store_config
self._graph = MemoryGraph()
def get_config(self):
"""Get the graph store config."""
return self._graph_store_config
def get_edge_type(self) -> str:
"""Get the edge type."""
raise NotImplementedError("Memory graph store does not have edge type")
def get_vertex_type(self) -> str:
"""Get the vertex type."""
raise NotImplementedError("Memory graph store does not have vertex type")
def insert_triplet(self, sub: str, rel: str, obj: str):
"""Insert a triplet into the graph."""
self._graph.append_edge(Edge(sub, obj, **{self._edge_name_key: rel}))
self._graph.append_edge(Edge(sub, obj, rel))
def insert_graph(self, graph: Graph):
"""Add graph."""
for vertex in graph.vertices():
self._graph.upsert_vertex(vertex)
for edge in graph.edges():
self._graph.append_edge(edge)
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
"""Retrieve triplets originating from a subject."""
subgraph = self.explore([sub], direct=Direction.OUT, depth=1)
return [(e.get_prop(self._edge_name_key), e.tid) for e in subgraph.edges()]
return [(e.name, e.tid) for e in subgraph.edges()]
def delete_triplet(self, sub: str, rel: str, obj: str):
"""Delete a specific triplet from the graph."""
self._graph.del_edges(sub, obj, **{self._edge_name_key: rel})
self._graph.del_edges(sub, obj, rel)
def truncate(self):
"""Truncate graph."""
self._graph.truncate()
def drop(self):
"""Drop graph."""
@@ -50,7 +69,7 @@ class MemoryGraphStore(GraphStoreBase):
"""Return the graph schema as a JSON string."""
return json.dumps(self._graph.schema())
def get_full_graph(self, limit: Optional[int] = None) -> MemoryGraph:
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
"""Return self."""
if not limit:
return self._graph
@@ -79,3 +98,7 @@ class MemoryGraphStore(GraphStoreBase):
def query(self, query: str, **args) -> Graph:
"""Execute a query on graph."""
raise NotImplementedError("Query memory graph not allowed")
def stream_query(self, query: str) -> Generator[Graph, None, None]:
"""Execute stream query."""
raise NotImplementedError("Stream query memory graph not allowed")

View File

@@ -1,10 +1,8 @@
"""Neo4j vector store."""
"""Neo4j store."""
import logging
from typing import List, Optional, Tuple
from dbgpt._private.pydantic import ConfigDict
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import Direction, Graph, MemoryGraph
logger = logging.getLogger(__name__)
@@ -19,46 +17,3 @@ class Neo4jStore(GraphStoreBase):
"""Neo4j graph store."""
# todo: add neo4j implementation
def __init__(self, graph_store_config: Neo4jStoreConfig):
"""Initialize the Neo4jStore with connection details."""
pass
def insert_triplet(self, sub: str, rel: str, obj: str):
"""Insert triplets."""
pass
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
"""Get triplets."""
return []
def delete_triplet(self, sub: str, rel: str, obj: str):
"""Delete triplets."""
pass
def drop(self):
"""Drop graph."""
pass
def get_schema(self, refresh: bool = False) -> str:
"""Get schema."""
return ""
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
"""Get full graph."""
return MemoryGraph()
def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: Optional[int] = None,
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> Graph:
"""Explore the graph from given subjects up to a depth."""
return MemoryGraph()
def query(self, query: str, **args) -> Graph:
"""Execute a query on graph."""
return MemoryGraph()

View File

@@ -1,12 +1,14 @@
"""TuGraph vector store."""
"""TuGraph store."""
import base64
import json
import logging
import os
from typing import List, Optional, Tuple
from typing import Any, Generator, Iterator, List, Optional, Tuple
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.datasource.conn_tugraph import TuGraphConnector
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import Direction, Edge, MemoryGraph, Vertex
from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph, Vertex
logger = logging.getLogger(__name__)
@@ -29,20 +31,24 @@ class TuGraphStoreConfig(GraphStoreConfig):
description="login username",
)
password: str = Field(
default="123456",
default="73@TuGraph",
description="login password",
)
vertex_type: str = Field(
default="entity",
description="The type of graph vertex, `entity` by default.",
description="The type of vertex, `entity` by default.",
)
edge_type: str = Field(
default="relation",
description="The type of graph edge, `relation` by default.",
description="The type of edge, `relation` by default.",
)
edge_name_key: str = Field(
default="label",
description="The label of edge name, `label` by default.",
plugin_names: List[str] = Field(
default=["leiden"],
description=(
"Plugins need to be loaded when initialize TuGraph, "
"code: https://github.com/TuGraph-family"
"/dbgpt-tugraph-plugins/tree/master/cpp"
),
)
@@ -51,20 +57,23 @@ class TuGraphStore(GraphStoreBase):
def __init__(self, config: TuGraphStoreConfig) -> None:
"""Initialize the TuGraphStore with connection details."""
self._host = os.getenv("TUGRAPH_HOST", "127.0.0.1") or config.host
self._port = int(os.getenv("TUGRAPH_PORT", 7687)) or config.port
self._username = os.getenv("TUGRAPH_USERNAME", "admin") or config.username
self._password = os.getenv("TUGRAPH_PASSWORD", "73@TuGraph") or config.password
self._node_label = (
os.getenv("TUGRAPH_VERTEX_TYPE", "entity") or config.vertex_type
self._config = config
self._host = os.getenv("TUGRAPH_HOST", config.host)
self._port = int(os.getenv("TUGRAPH_PORT", config.port))
self._username = os.getenv("TUGRAPH_USERNAME", config.username)
self._password = os.getenv("TUGRAPH_PASSWORD", config.password)
self._summary_enabled = (
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
or config.summary_enabled
)
self._edge_label = (
os.getenv("TUGRAPH_EDGE_TYPE", "relation") or config.edge_type
)
self.edge_name_key = (
os.getenv("TUGRAPH_EDGE_NAME_KEY", "label") or config.edge_name_key
self._plugin_names = (
os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
or config.plugin_names
)
self._graph_name = config.name
self._vertex_type = os.getenv("TUGRAPH_VERTEX_TYPE", config.vertex_type)
self._edge_type = os.getenv("TUGRAPH_EDGE_TYPE", config.edge_type)
self.conn = TuGraphConnector.from_uri_db(
host=self._host,
port=self._port,
@@ -72,35 +81,197 @@ class TuGraphStore(GraphStoreBase):
pwd=self._password,
db_name=config.name,
)
self.conn.create_graph(graph_name=config.name)
self._create_graph(config.name)
def get_vertex_type(self) -> str:
"""Get the vertex type."""
return self._vertex_type
def get_edge_type(self) -> str:
"""Get the edge type."""
return self._edge_type
def _create_graph(self, graph_name: str):
self.conn.create_graph(graph_name=graph_name)
self._create_schema()
if self._summary_enabled:
self._upload_plugin()
def _check_label(self, elem_type: str):
result = self.conn.get_table_names()
if elem_type == "vertex":
return self._node_label in result["vertex_tables"]
return self._vertex_type in result["vertex_tables"]
if elem_type == "edge":
return self._edge_label in result["edge_tables"]
return self._edge_type in result["edge_tables"]
def _add_vertex_index(self, field_name):
gql = f"CALL db.addIndex('{self._vertex_type}', '{field_name}', false)"
self.conn.run(gql)
def _upload_plugin(self):
gql = "CALL db.plugin.listPlugin('CPP','v1')"
result = self.conn.run(gql)
result_names = [
json.loads(record["plugin_description"])["name"] for record in result
]
missing_plugins = [
name for name in self._plugin_names if name not in result_names
]
if len(missing_plugins):
for name in missing_plugins:
try:
from dbgpt_tugraph_plugins import ( # type: ignore # noqa
get_plugin_binary_path,
)
except ImportError:
logger.error(
"dbgpt-tugraph-plugins is not installed, "
"pip install dbgpt-tugraph-plugins==0.1.0rc1 -U -i "
"https://pypi.org/simple"
)
raise
plugin_path = get_plugin_binary_path("leiden")
with open(plugin_path, "rb") as f:
content = f.read()
content = base64.b64encode(content).decode()
gql = (
f"CALL db.plugin.loadPlugin('CPP', '{name}', '{content}', "
"'SO', '{name} Plugin', false, 'v1')"
)
self.conn.run(gql)
def _create_schema(self):
if not self._check_label("vertex"):
create_vertex_gql = (
f"CALL db.createLabel("
f"'vertex', '{self._node_label}', "
f"'id', ['id',string,false])"
)
self.conn.run(create_vertex_gql)
if self._summary_enabled:
create_vertex_gql = (
f"CALL db.createLabel("
f"'vertex', '{self._vertex_type}', "
f"'id', ['id',string,false],"
f"['name',string,false],"
f"['_document_id',string,true],"
f"['_chunk_id',string,true],"
f"['_community_id',string,true],"
f"['description',string,true])"
)
self.conn.run(create_vertex_gql)
self._add_vertex_index("_community_id")
else:
create_vertex_gql = (
f"CALL db.createLabel("
f"'vertex', '{self._vertex_type}', "
f"'id', ['id',string,false],"
f"['name',string,false])"
)
self.conn.run(create_vertex_gql)
if not self._check_label("edge"):
create_edge_gql = f"""CALL db.createLabel(
'edge', '{self._edge_label}', '[["{self._node_label}",
"{self._node_label}"]]', ["id",STRING,false])"""
'edge', '{self._edge_type}',
'[["{self._vertex_type}",
"{self._vertex_type}"]]',
["id",STRING,false],
["name",STRING,false])"""
if self._summary_enabled:
create_edge_gql = f"""CALL db.createLabel(
'edge', '{self._edge_type}',
'[["{self._vertex_type}",
"{self._vertex_type}"]]',
["id",STRING,false],
["name",STRING,false],
["description",STRING,true])"""
self.conn.run(create_edge_gql)
def _format_query_data(self, data, white_prop_list: List[str]):
nodes_list = []
rels_list: List[Any] = []
_white_list = white_prop_list
from neo4j import graph
def get_filtered_properties(properties, white_list):
return {
key: value
for key, value in properties.items()
if (not key.startswith("_") and key not in ["id", "name"])
or key in white_list
}
def process_node(node: graph.Node):
node_id = node._properties.get("id")
node_name = node._properties.get("name")
node_properties = get_filtered_properties(node._properties, _white_list)
nodes_list.append(
{"id": node_id, "name": node_name, "properties": node_properties}
)
def process_relationship(rel: graph.Relationship):
name = rel._properties.get("name", "")
rel_nodes = rel.nodes
src_id = rel_nodes[0]._properties.get("id")
dst_id = rel_nodes[1]._properties.get("id")
for node in rel_nodes:
process_node(node)
edge_properties = get_filtered_properties(rel._properties, _white_list)
if not any(
existing_edge.get("name") == name
and existing_edge.get("src_id") == src_id
and existing_edge.get("dst_id") == dst_id
for existing_edge in rels_list
):
rels_list.append(
{
"src_id": src_id,
"dst_id": dst_id,
"name": name,
"properties": edge_properties,
}
)
def process_path(path: graph.Path):
for rel in path.relationships:
process_relationship(rel)
def process_other(value):
if not any(
existing_node.get("id") == "json_node" for existing_node in nodes_list
):
nodes_list.append(
{
"id": "json_node",
"name": "json_node",
"properties": {"description": value},
}
)
for record in data:
for key in record.keys():
value = record[key]
if isinstance(value, graph.Node):
process_node(value)
elif isinstance(value, graph.Relationship):
process_relationship(value)
elif isinstance(value, graph.Path):
process_path(value)
else:
process_other(value)
nodes = [
Vertex(node["id"], node["name"], **node["properties"])
for node in nodes_list
]
rels = [
Edge(edge["src_id"], edge["dst_id"], edge["name"], **edge["properties"])
for edge in rels_list
]
return {"nodes": nodes, "edges": rels}
def get_config(self):
"""Get the graph store config."""
return self._config
def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
"""Get triplets."""
query = (
f"MATCH (n1:{self._node_label})-[r]->(n2:{self._node_label}) "
f"MATCH (n1:{self._vertex_type})-[r]->(n2:{self._vertex_type}) "
f'WHERE n1.id = "{subj}" RETURN r.id as rel, n2.id as obj;'
)
data = self.conn.run(query)
@@ -117,16 +288,83 @@ class TuGraphStore(GraphStoreBase):
rel_escaped = escape_quotes(rel)
obj_escaped = escape_quotes(obj)
subj_query = f"MERGE (n1:{self._node_label} {{id:'{subj_escaped}'}})"
obj_query = f"MERGE (n1:{self._node_label} {{id:'{obj_escaped}'}})"
rel_query = (
f"MERGE (n1:{self._node_label} {{id:'{subj_escaped}'}})"
f"-[r:{self._edge_label} {{id:'{rel_escaped}'}}]->"
f"(n2:{self._node_label} {{id:'{obj_escaped}'}})"
node_query = f"""CALL db.upsertVertex(
'{self._vertex_type}',
[{{id:'{subj_escaped}',name:'{subj_escaped}'}},
{{id:'{obj_escaped}',name:'{obj_escaped}'}}])"""
edge_query = f"""CALL db.upsertEdge(
'{self._edge_type}',
{{type:"{self._vertex_type}",key:"sid"}},
{{type:"{self._vertex_type}", key:"tid"}},
[{{sid:"{subj_escaped}",
tid: "{obj_escaped}",
id:"{rel_escaped}",
name: "{rel_escaped}"}}])"""
self.conn.run(query=node_query)
self.conn.run(query=edge_query)
def insert_graph(self, graph: Graph) -> None:
"""Add graph."""
def escape_quotes(value: str) -> str:
"""Escape single and double quotes in a string for queries."""
if value is not None:
return value.replace("'", "").replace('"', "")
nodes: Iterator[Vertex] = graph.vertices()
edges: Iterator[Edge] = graph.edges()
node_list = []
edge_list = []
def parser(node_list):
formatted_nodes = [
"{"
+ ", ".join(
f'{k}: "{v}"' if isinstance(v, str) else f"{k}: {v}"
for k, v in node.items()
)
+ "}"
for node in node_list
]
return f"""{', '.join(formatted_nodes)}"""
for node in nodes:
node_list.append(
{
"id": escape_quotes(node.vid),
"name": escape_quotes(node.name),
"description": escape_quotes(node.get_prop("description")) or "",
"_document_id": "0",
"_chunk_id": "0",
"_community_id": "0",
}
)
node_query = (
f"""CALL db.upsertVertex("{self._vertex_type}", [{parser(node_list)}])"""
)
self.conn.run(query=subj_query)
self.conn.run(query=obj_query)
self.conn.run(query=rel_query)
for edge in edges:
edge_list.append(
{
"sid": escape_quotes(edge.sid),
"tid": escape_quotes(edge.tid),
"id": escape_quotes(edge.name),
"name": escape_quotes(edge.name),
"description": escape_quotes(edge.get_prop("description")),
}
)
edge_query = f"""CALL db.upsertEdge(
"{self._edge_type}",
{{type:"{self._vertex_type}", key:"sid"}},
{{type:"{self._vertex_type}", key:"tid"}},
[{parser(edge_list)}])"""
self.conn.run(query=node_query)
self.conn.run(query=edge_query)
def truncate(self):
"""Truncate Graph."""
gql = "MATCH (n) DELETE n"
self.conn.run(gql)
def drop(self):
"""Delete Graph."""
@@ -135,9 +373,9 @@ class TuGraphStore(GraphStoreBase):
def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
"""Delete triplet."""
del_query = (
f"MATCH (n1:{self._node_label} {{id:'{sub}'}})"
f"-[r:{self._edge_label} {{id:'{rel}'}}]->"
f"(n2:{self._node_label} {{id:'{obj}'}}) DELETE n1,n2,r"
f"MATCH (n1:{self._vertex_type} {{id:'{sub}'}})"
f"-[r:{self._edge_type} {{id:'{rel}'}}]->"
f"(n2:{self._vertex_type} {{id:'{obj}'}}) DELETE n1,n2,r"
)
self.conn.run(query=del_query)
@@ -148,11 +386,20 @@ class TuGraphStore(GraphStoreBase):
schema = data[0]["schema"]
return schema
def get_full_graph(self, limit: Optional[int] = None) -> MemoryGraph:
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
"""Get full graph."""
if not limit:
raise Exception("limit must be set")
return self.query(f"MATCH (n)-[r]-(m) RETURN n,m,r LIMIT {limit}")
graph_result = self.query(
f"MATCH (n)-[r]-(m) RETURN n,r,m LIMIT {limit}",
white_list=["_community_id"],
)
all_graph = MemoryGraph()
for vertex in graph_result.vertices():
all_graph.upsert_vertex(vertex)
for edge in graph_result.edges():
all_graph.append_edge(edge)
return all_graph
def explore(
self,
@@ -161,8 +408,11 @@ class TuGraphStore(GraphStoreBase):
depth: Optional[int] = None,
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> MemoryGraph:
) -> Graph:
"""Explore the graph from given subjects up to a depth."""
if not subs:
return MemoryGraph()
if fan is not None:
raise ValueError("Fan functionality is not supported at this time.")
else:
@@ -173,67 +423,88 @@ class TuGraphStore(GraphStoreBase):
limit_string = f"LIMIT {limit}"
if limit is None:
limit_string = ""
if direct.name == "OUT":
rel = f"-[r:{self._edge_type}*{depth_string}]->"
elif direct.name == "IN":
rel = f"<-[r:{self._edge_type}*{depth_string}]-"
else:
rel = f"-[r:{self._edge_type}*{depth_string}]-"
query = (
f"MATCH p=(n:{self._node_label})"
f"-[r:{self._edge_label}*{depth_string}]-(m:{self._node_label}) "
f"MATCH p=(n:{self._vertex_type})"
f"{rel}(m:{self._vertex_type}) "
f"WHERE n.id IN {subs} RETURN p {limit_string}"
)
return self.query(query)
def query(self, query: str, **args) -> MemoryGraph:
"""Execute a query on graph."""
def _format_paths(paths):
formatted_paths = []
for path in paths:
formatted_path = []
nodes = list(path["p"].nodes)
rels = list(path["p"].relationships)
for i in range(len(nodes)):
formatted_path.append(nodes[i]._properties["id"])
if i < len(rels):
formatted_path.append(rels[i]._properties["id"])
formatted_paths.append(formatted_path)
return formatted_paths
def _format_query_data(data):
node_ids_set = set()
rels_set = set()
from neo4j import graph
for record in data:
for key in record.keys():
value = record[key]
if isinstance(value, graph.Node):
node_id = value._properties["id"]
node_ids_set.add(node_id)
elif isinstance(value, graph.Relationship):
rel_nodes = value.nodes
prop_id = value._properties["id"]
src_id = rel_nodes[0]._properties["id"]
dst_id = rel_nodes[1]._properties["id"]
rels_set.add((src_id, dst_id, prop_id))
elif isinstance(value, graph.Path):
formatted_paths = _format_paths(data)
for path in formatted_paths:
for i in range(0, len(path), 2):
node_ids_set.add(path[i])
if i + 2 < len(path):
rels_set.add((path[i], path[i + 2], path[i + 1]))
nodes = [Vertex(node_id) for node_id in node_ids_set]
rels = [
Edge(src_id, dst_id, label=prop_id)
for (src_id, dst_id, prop_id) in rels_set
]
return {"nodes": nodes, "edges": rels}
result = self.conn.run(query=query)
graph = _format_query_data(result)
white_list = args.get("white_list", [])
graph = self._format_query_data(result, white_list)
mg = MemoryGraph()
for vertex in graph["nodes"]:
mg.upsert_vertex(vertex)
for edge in graph["edges"]:
mg.append_edge(edge)
return mg
def stream_query(self, query: str) -> Generator[Graph, None, None]:
"""Execute a stream query."""
from neo4j import graph
for record in self.conn.run_stream(query):
mg = MemoryGraph()
for key in record.keys():
value = record[key]
if isinstance(value, graph.Node):
node_id = value._properties["id"]
description = value._properties["description"]
vertex = Vertex(node_id, name=node_id, description=description)
mg.upsert_vertex(vertex)
elif isinstance(value, graph.Relationship):
rel_nodes = value.nodes
prop_id = value._properties["id"]
src_id = rel_nodes[0]._properties["id"]
dst_id = rel_nodes[1]._properties["id"]
description = value._properties["description"]
edge = Edge(src_id, dst_id, name=prop_id, description=description)
mg.append_edge(edge)
elif isinstance(value, graph.Path):
nodes = list(record["p"].nodes)
rels = list(record["p"].relationships)
formatted_path = []
for i in range(len(nodes)):
formatted_path.append(
{
"id": nodes[i]._properties["id"],
"description": nodes[i]._properties["description"],
}
)
if i < len(rels):
formatted_path.append(
{
"id": rels[i]._properties["id"],
"description": rels[i]._properties["description"],
}
)
for i in range(0, len(formatted_path), 2):
mg.upsert_vertex(
Vertex(
formatted_path[i]["id"],
name=formatted_path[i]["id"],
description=formatted_path[i]["description"],
)
)
if i + 2 < len(formatted_path):
mg.append_edge(
Edge(
formatted_path[i]["id"],
formatted_path[i + 2]["id"],
name=formatted_path[i + 1]["id"],
description=formatted_path[i + 1]["description"],
)
)
else:
vertex = Vertex("json_node", name="json_node", description=value)
mg.upsert_vertex(vertex)
yield mg
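A hedged wiring example for the reworked TuGraph store; connection values are placeholders, environment variables such as TUGRAPH_HOST still take priority, and summary_enabled triggers the extended schema plus the leiden plugin upload at construction time.

from dbgpt.storage.graph_store.tugraph_store import TuGraphStore, TuGraphStoreConfig

config = TuGraphStoreConfig(
    name="dbgpt_graph",      # graph (database) name, placeholder
    host="127.0.0.1",
    port=7687,
    username="admin",
    password="73@TuGraph",
    summary_enabled=True,    # creates the summary schema and uploads the leiden plugin
)
store = TuGraphStore(config)
store.insert_triplet("alice", "knows", "bob")
print(store.get_triplets("alice"))  # e.g. [('knows', 'bob')]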

View File

@@ -19,6 +19,10 @@ class KnowledgeGraphConfig(IndexStoreConfig):
class KnowledgeGraphBase(IndexStoreBase, ABC):
"""Knowledge graph base class."""
@abstractmethod
def get_config(self) -> KnowledgeGraphConfig:
"""Get the knowledge graph config."""
@abstractmethod
def query_graph(self, limit: Optional[int] = None) -> Graph:
"""Get graph data."""

View File

@@ -0,0 +1 @@
"""Community Module."""

View File

@@ -0,0 +1,73 @@
"""Define Classes about Community."""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.graph import Graph
logger = logging.getLogger(__name__)
@dataclass
class Community:
"""Community class."""
id: str
data: Optional[Graph] = None
summary: Optional[str] = None
@dataclass
class CommunityTree:
"""Represents a community tree."""
class CommunityStoreAdapter(ABC):
"""Community Store Adapter."""
def __init__(self, graph_store: GraphStoreBase):
"""Initialize Community Store Adapter."""
self._graph_store = graph_store
@property
def graph_store(self) -> GraphStoreBase:
"""Get graph store."""
return self._graph_store
@abstractmethod
async def discover_communities(self, **kwargs) -> List[str]:
"""Run community discovery."""
@abstractmethod
async def get_community(self, community_id: str) -> Community:
"""Get community."""
class CommunityMetastore(ABC):
"""Community metastore class."""
@abstractmethod
def get(self, community_id: str) -> Community:
"""Get community."""
@abstractmethod
def list(self) -> List[Community]:
"""Get all communities."""
@abstractmethod
async def search(self, query: str) -> List[Community]:
"""Search communities relevant to query."""
@abstractmethod
async def save(self, communities: List[Community]):
"""Save communities."""
@abstractmethod
async def truncate(self):
"""Truncate all communities."""
@abstractmethod
def drop(self):
"""Drop community metastore."""

View File

@@ -0,0 +1,63 @@
"""Builtin Community metastore."""
import logging
from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.storage.knowledge_graph.community.base import Community, CommunityMetastore
from dbgpt.storage.vector_store.base import VectorStoreBase
logger = logging.getLogger(__name__)
class BuiltinCommunityMetastore(CommunityMetastore):
"""Builtin Community metastore."""
def __init__(
self, vector_store: VectorStoreBase, rdb_store: Optional[RDBMSConnector] = None
):
"""Initialize Community metastore."""
self._vector_store = vector_store
self._rdb_store = rdb_store
config = self._vector_store.get_config()
self._vector_space = config.name
self._max_chunks_once_load = config.max_chunks_once_load
self._max_threads = config.max_threads
self._topk = config.topk
self._score_threshold = config.score_threshold
def get(self, community_id: str) -> Community:
"""Get community."""
raise NotImplementedError("Get community not allowed")
def list(self) -> List[Community]:
"""Get all communities."""
raise NotImplementedError("List communities not allowed")
async def search(self, query: str) -> List[Community]:
"""Search communities relevant to query."""
chunks = await self._vector_store.asimilar_search_with_scores(
query, self._topk, self._score_threshold
)
return [Community(id=chunk.chunk_id, summary=chunk.content) for chunk in chunks]
async def save(self, communities: List[Community]):
"""Save communities."""
chunks = [
Chunk(id=c.id, content=c.summary, metadata={"total": len(communities)})
for c in communities
]
await self._vector_store.aload_document_with_limit(
chunks, self._max_chunks_once_load, self._max_threads
)
logger.info(f"Save {len(communities)} communities")
async def truncate(self):
"""Truncate community metastore."""
self._vector_store.truncate()
def drop(self):
"""Drop community metastore."""
if self._vector_store.vector_name_exists():
self._vector_store.delete_vector_name(self._vector_space)
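A hedged sketch of driving the metastore directly; vector_store stands in for any configured VectorStoreBase and is normally created by CommunityStore below.

from dbgpt.storage.knowledge_graph.community.base import Community
from dbgpt.storage.knowledge_graph.community.community_metastore import (
    BuiltinCommunityMetastore,
)


async def save_and_search(vector_store, question: str):
    metastore = BuiltinCommunityMetastore(vector_store)
    await metastore.save([Community(id="0", summary="Alice and Bob form a small cluster.")])
    return await metastore.search(question)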

View File

@@ -0,0 +1,83 @@
"""Define the CommunityStore class."""
import logging
from typing import List
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
from dbgpt.storage.knowledge_graph.community.base import (
Community,
CommunityStoreAdapter,
)
from dbgpt.storage.knowledge_graph.community.community_metastore import (
BuiltinCommunityMetastore,
)
from dbgpt.storage.vector_store.base import VectorStoreBase
logger = logging.getLogger(__name__)
class CommunityStore:
"""CommunityStore Class."""
def __init__(
self,
community_store_adapter: CommunityStoreAdapter,
community_summarizer: CommunitySummarizer,
vector_store: VectorStoreBase,
):
"""Initialize the CommunityStore class."""
self._community_store_adapter = community_store_adapter
self._community_summarizer = community_summarizer
self._meta_store = BuiltinCommunityMetastore(vector_store)
async def build_communities(self):
"""Discover communities."""
community_ids = await (self._community_store_adapter.discover_communities())
# summarize communities
communities = []
for community_id in community_ids:
community = await (
self._community_store_adapter.get_community(community_id)
)
graph = community.data.format()
if not graph:
break
community.summary = await (
self._community_summarizer.summarize(graph=graph)
)
communities.append(community)
logger.info(
f"Summarize community {community_id}: " f"{community.summary[:50]}..."
)
# truncate then save new summaries
await self._meta_store.truncate()
await self._meta_store.save(communities)
async def search_communities(self, query: str) -> List[Community]:
"""Search communities."""
return await self._meta_store.search(query)
def truncate(self):
"""Truncate community store."""
logger.info("Truncate community metastore")
self._meta_store.truncate()
logger.info("Truncate community summarizer")
self._community_summarizer.truncate()
logger.info("Truncate graph")
self._community_store_adapter.graph_store.truncate()
def drop(self):
"""Drop community store."""
logger.info("Remove community metastore")
self._meta_store.drop()
logger.info("Remove community summarizer")
self._community_summarizer.drop()
logger.info("Remove graph")
self._community_store_adapter.graph_store.drop()
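Roughly, the intended call flow looks like the sketch below; the three constructor arguments are placeholders that CommunitySummaryKnowledgeGraph builds later in this PR.

from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore


async def rebuild_and_query(adapter, summarizer, vector_store, question: str):
    store = CommunityStore(adapter, summarizer, vector_store)
    await store.build_communities()          # discover -> summarize -> truncate -> save
    return await store.search_communities(question)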

View File

@@ -0,0 +1,30 @@
"""CommunityStoreAdapter factory."""
import logging
from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
from dbgpt.storage.knowledge_graph.community.base import CommunityStoreAdapter
from dbgpt.storage.knowledge_graph.community.tugraph_adapter import (
TuGraphCommunityStoreAdapter,
)
logger = logging.getLogger(__name__)
class CommunityStoreAdapterFactory:
"""Factory for community store adapter."""
@staticmethod
def create(graph_store: GraphStoreBase) -> CommunityStoreAdapter:
"""Create a CommunityStoreAdapter instance.
Args:
- graph_store: the graph store instance (currently only TuGraph is supported)
"""
if isinstance(graph_store, TuGraphStore):
return TuGraphCommunityStoreAdapter(graph_store)
else:
raise Exception(
f"Create community store adapter for "
f"{graph_store.__class__.__name__} failed"
)

View File

@@ -0,0 +1,52 @@
"""TuGraph Community Store Adapter."""
import json
import logging
from typing import List
from dbgpt.storage.graph_store.graph import MemoryGraph
from dbgpt.storage.knowledge_graph.community.base import (
Community,
CommunityStoreAdapter,
)
logger = logging.getLogger(__name__)
class TuGraphCommunityStoreAdapter(CommunityStoreAdapter):
"""TuGraph Community Store Adapter."""
MAX_HIERARCHY_LEVEL = 3
async def discover_communities(self, **kwargs) -> List[str]:
"""Run community discovery with leiden."""
mg = self._graph_store.query(
"CALL db.plugin.callPlugin"
"('CPP','leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
)
result = mg.get_vertex("json_node").get_prop("description")
community_ids = json.loads(result)["community_id_list"]
logger.info(f"Discovered {len(community_ids)} communities.")
return community_ids
async def get_community(self, community_id: str) -> Community:
"""Get community."""
query = (
f"MATCH (n:{self._graph_store.get_vertex_type()})"
f"WHERE n._community_id = '{community_id}' RETURN n"
)
edge_query = (
f"MATCH (n:{self._graph_store.get_vertex_type()})-"
f"[r:{self._graph_store.get_edge_type()}]-"
f"(m:{self._graph_store.get_vertex_type()})"
f"WHERE n._community_id = '{community_id}' RETURN n,r,m"
)
all_vertex_graph = self._graph_store.aquery(query)
all_edge_graph = self._graph_store.aquery(edge_query)
all_graph = MemoryGraph()
for vertex in all_vertex_graph.vertices():
all_graph.upsert_vertex(vertex)
for edge in all_edge_graph.edges():
all_graph.append_edge(edge)
return Community(id=community_id, data=all_graph)
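For reference, a hedged sketch of exercising the adapter on its own; a TuGraphStore constructed with summary_enabled=True is assumed so that the leiden plugin is available.

from dbgpt.storage.knowledge_graph.community.tugraph_adapter import (
    TuGraphCommunityStoreAdapter,
)


async def discover(tugraph_store):
    adapter = TuGraphCommunityStoreAdapter(tugraph_store)
    community_ids = await adapter.discover_communities()
    return [await adapter.get_community(cid) for cid in community_ids]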

View File

@@ -0,0 +1,373 @@
"""Define the CommunitySummaryKnowledgeGraph."""
import logging
import os
from typing import List, Optional
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import Chunk
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
from dbgpt.rag.transformer.graph_extractor import GraphExtractor
from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore
from dbgpt.storage.knowledge_graph.community.factory import CommunityStoreAdapterFactory
from dbgpt.storage.knowledge_graph.knowledge_graph import (
BuiltinKnowledgeGraph,
BuiltinKnowledgeGraphConfig,
)
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.factory import VectorStoreFactory
from dbgpt.storage.vector_store.filters import MetadataFilters
logger = logging.getLogger(__name__)
class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
"""Community summary knowledge graph config."""
model_config = ConfigDict(arbitrary_types_allowed=True)
vector_store_type: str = Field(
default="Chroma", description="The type of vector store."
)
user: Optional[str] = Field(
default=None,
description="The user of vector store, if not set, will use the default user.",
)
password: Optional[str] = Field(
default=None,
description=(
"The password of vector store, if not set, will use the default password."
),
)
extract_topk: int = Field(
default=5,
description="Topk of knowledge graph extract",
)
extract_score_threshold: float = Field(
default=0.3,
description="Recall score of knowledge graph extract",
)
community_topk: int = Field(
default=50,
description="Topk of community search in knowledge graph",
)
community_score_threshold: float = Field(
default=0.0,
description="Recall score of community search in knowledge graph",
)
class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
"""Community summary knowledge graph class."""
def __init__(self, config: CommunitySummaryKnowledgeGraphConfig):
"""Initialize community summary knowledge graph class."""
super().__init__(config)
self._config = config
self._vector_store_type = os.getenv(
"VECTOR_STORE_TYPE", config.vector_store_type
)
self._extract_topk = int(
os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE", config.extract_topk)
)
self._extract_score_threshold = float(
os.getenv(
"KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE",
config.extract_score_threshold,
)
)
self._community_topk = int(
os.getenv(
"KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE", config.community_topk
)
)
self._community_score_threshold = float(
os.getenv(
"KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE",
config.community_score_threshold,
)
)
def extractor_configure(name: str, cfg: VectorStoreConfig):
cfg.name = name
cfg.embedding_fn = config.embedding_fn
cfg.max_chunks_once_load = config.max_chunks_once_load
cfg.max_threads = config.max_threads
cfg.user = config.user
cfg.password = config.password
cfg.topk = self._extract_topk
cfg.score_threshold = self._extract_score_threshold
self._graph_extractor = GraphExtractor(
self._llm_client,
self._model_name,
VectorStoreFactory.create(
self._vector_store_type,
config.name + "_CHUNK_HISTORY",
extractor_configure,
),
)
def community_store_configure(name: str, cfg: VectorStoreConfig):
cfg.name = name
cfg.embedding_fn = config.embedding_fn
cfg.max_chunks_once_load = config.max_chunks_once_load
cfg.max_threads = config.max_threads
cfg.user = config.user
cfg.password = config.password
cfg.topk = self._community_topk
cfg.score_threshold = self._community_score_threshold
self._community_store = CommunityStore(
CommunityStoreAdapterFactory.create(self._graph_store),
CommunitySummarizer(self._llm_client, self._model_name),
VectorStoreFactory.create(
self._vector_store_type,
config.name + "_COMMUNITY_SUMMARY",
community_store_configure,
),
)
def get_config(self) -> BuiltinKnowledgeGraphConfig:
"""Get the knowledge graph config."""
return self._config
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
"""Extract and persist graph."""
# todo add doc node
for chunk in chunks:
# todo add chunk node
# todo add relation doc-chunk
# extract graphs and save
graphs = await self._graph_extractor.extract(chunk.content)
for graph in graphs:
self._graph_store.insert_graph(graph)
# build communities and save
await self._community_store.build_communities()
return [chunk.chunk_id for chunk in chunks]
async def asimilar_search_with_scores(
self,
text,
topk,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve relevant community summaries."""
# global search: retrieve relevant community summaries
communities = await self._community_store.search_communities(text)
summaries = [
f"Section {i + 1}:\n{community.summary}"
for i, community in enumerate(communities)
]
context = "\n".join(summaries) if summaries else ""
# local search: extract keywords and explore subgraph
keywords = await self._keyword_extractor.extract(text)
subgraph = self._graph_store.explore(keywords, limit=topk).format()
logger.info(f"Search subgraph from {len(keywords)} keywords")
if not summaries and not subgraph:
return []
# merge search results into context
content = HYBRID_SEARCH_PT_CN.format(context=context, graph=subgraph)
return [Chunk(content=content)]
def truncate(self) -> List[str]:
"""Truncate knowledge graph."""
logger.info("Truncate community store")
self._community_store.truncate()
logger.info("Truncate keyword extractor")
self._keyword_extractor.truncate()
logger.info("Truncate triplet extractor")
self._graph_extractor.truncate()
return [self._config.name]
def delete_vector_name(self, index_name: str):
"""Delete knowledge graph."""
logger.info("Drop community store")
self._community_store.drop()
logger.info("Drop keyword extractor")
self._keyword_extractor.drop()
logger.info("Drop triplet extractor")
self._graph_extractor.drop()
HYBRID_SEARCH_PT_CN = (
"## 角色\n"
"你非常擅长结合提示词模板提供的[上下文]信息与[知识图谱]信息,"
"准确恰当地回答用户的问题,并保证不会输出与上下文和知识图谱无关的信息。"
"\n"
"## 技能\n"
"### 技能 1: 上下文理解\n"
"- 准确地理解[上下文]提供的信息,上下文信息可能被拆分为多个章节。\n"
"- 上下文的每个章节内容都会以[Section]开始,并按需进行了编号。\n"
"- 上下文信息提供了与用户问题相关度最高的总结性描述,请合理使用它们。"
"### 技能 2: 知识图谱理解\n"
"- 准确地识别[知识图谱]中提供的[Entities:]章节中的实体信息"
"和[Relationships:]章节中的关系信息,实体和关系信息的一般格式为:\n"
"```"
"* 实体信息格式:\n"
"- (实体名)\n"
"- (实体名:实体描述)\n"
"- (实体名:实体属性表)\n"
"- (文本块ID:文档块内容)\n"
"- (目录ID:目录名)\n"
"- (文档ID:文档名称)\n"
"\n"
"* 关系信息的格式:\n"
"- (来源实体名)-[关系名]->(目标实体名)\n"
"- (来源实体名)-[关系名:关系描述]->(目标实体名)\n"
"- (来源实体名)-[关系名:关系属性表]->(目标实体名)\n"
"- (文本块实体)-[包含]->(实体名)\n"
"- (目录ID)-[包含]->(文本块实体)\n"
"- (目录ID)-[包含]->(子目录ID)\n"
"- (文档ID)-[包含]->(文本块实体)\n"
"- (文档ID)-[包含]->(目录ID)\n"
"```"
"- 正确地将关系信息中的实体名/ID与实体信息关联还原出图结构。"
"- 将图结构所表达的信息作为用户提问的明细上下文,辅助生成更好的答案。\n"
"\n"
"## 约束条件\n"
"- 不要在答案中描述你的思考过程,直接给出用户问题的答案,不要生成无关信息。\n"
"- 若[知识图谱]没有提供信息,此时应根据[上下文]提供的信息回答问题。"
"- 确保以第三人称书写,从客观角度结合[上下文]和[知识图谱]表达的信息回答问题。\n"
"- 若提供的信息相互矛盾,请解决矛盾并提供一个单一、连贯的描述。\n"
"- 避免使用停用词和过于常见的词汇。\n"
"\n"
"## 参考案例\n"
"```\n"
"[上下文]:\n"
"Section 1:\n"
"菲尔・贾伯的大儿子叫雅各布・贾伯。\n"
"Section 2:\n"
"菲尔・贾伯的小儿子叫比尔・贾伯。\n"
"[知识图谱]:\n"
"Entities:\n"
"(菲尔・贾伯#菲尔兹咖啡创始人)\n"
"(菲尔兹咖啡#加利福尼亚州伯克利创立的咖啡品牌)\n"
"(雅各布・贾伯#菲尔・贾伯的儿子)\n"
"(美国多地#菲尔兹咖啡的扩展地区)\n"
"\n"
"Relationships:\n"
"(菲尔・贾伯#创建#菲尔兹咖啡#1978年在加利福尼亚州伯克利创立)\n"
"(菲尔兹咖啡#位于#加利福尼亚州伯克利#菲尔兹咖啡的创立地点)\n"
"(菲尔・贾伯#拥有#雅各布・贾伯#菲尔・贾伯的儿子)\n"
"(雅各布・贾伯#担任#首席执行官#在2005年成为菲尔兹咖啡的首席执行官)\n"
"(菲尔兹咖啡#扩展至#美国多地#菲尔兹咖啡的扩展范围)\n"
"```\n"
"\n"
"----\n"
"\n"
"接下来的[上下文]和[知识图谱]的信息,可以帮助你回答更好地用户的问题。\n"
"\n"
"[上下文]:\n"
"{context}\n"
"\n"
"[知识图谱]:\n"
"{graph}\n"
"\n"
)
HYBRID_SEARCH_PT_EN = (
"## Role\n"
"You excel at combining the information provided in the [Context] with "
"information from the [KnowledgeGraph] to accurately and appropriately "
"answer user questions, ensuring that you do not output information "
"unrelated to the context and knowledge graph.\n"
"\n"
"## Skills\n"
"### Skill 1: Context Understanding\n"
"- Accurately understand the information provided in the [Context], "
"which may be divided into several sections.\n"
"- Each section in the context will start with [Section] "
"and may be numbered as needed.\n"
"- The context provides a summary description most relevant to the users "
"question, and it should be used wisely."
"### Skill 2: Knowledge Graph Understanding\n"
"- Accurately identify entity information in the [Entities:] section and "
"relationship information in the [Relationships:] section "
"of the [KnowledgeGraph]. The general format for entity "
"and relationship information is:\n"
"```"
"* Entity Information Format:\n"
"- (entity_name)\n"
"- (entity_name: entity_description)\n"
"- (entity_name: entity_property_map)\n"
"- (chunk_id: chunk_content)\n"
"- (catalog_id: catalog_name)\n"
"- (document_id: document_name)\n"
"\n"
"* Relationship Information Format:\n"
"- (source_entity_name)-[relationship_name]->(target_entity_name)\n"
"- (source_entity_name)-[relationship_name: relationship_description]->"
"(target_entity_name)\n"
"- (source_entity_name)-[relationship_name: relationship_property_map]->"
"(target_entity_name)\n"
"- (chunk_id)-[Contains]->(entity_name)\n"
"- (catalog_id)-[Contains]->(chunk_id)\n"
"- (catalog_id)-[Contains]->(sub_catalog_id)\n"
"- (document_id)-[Contains]->(chunk_id)\n"
"- (document_id)-[Contains]->(catalog_id)\n"
"```"
"- Correctly associate entity names/IDs in the relationship information "
"with entity information to restore the graph structure."
"- Use the information expressed by the graph structure as detailed "
"context for the user's query to assist in generating better answers.\n"
"\n"
"## Constraints\n"
"- Don't describe your thought process in the answer, provide the answer "
"to the user's question directly without generating irrelevant information."
"- If the [KnowledgeGraph] does not provide information, you should answer "
"the question based on the information provided in the [Context]."
"- Ensure to write in the third person, responding to questions from "
"an objective perspective based on the information combined from the "
"[Context] and the [KnowledgeGraph].\n"
"- If the provided information is contradictory, resolve the "
"contradictions and provide a single, coherent description.\n"
"- Avoid using stop words and overly common vocabulary.\n"
"\n"
"## Reference Example\n"
"```\n"
"[Context]:\n"
"Section 1:\n"
"Phil Schiller's eldest son is Jacob Schiller.\n"
"Section 2:\n"
"Phil Schiller's youngest son is Bill Schiller.\n"
"[KnowledgeGraph]:\n"
"Entities:\n"
"(Phil Jaber#Founder of Philz Coffee)\n"
"(Philz Coffee#Coffee brand founded in Berkeley, California)\n"
"(Jacob Jaber#Son of Phil Jaber)\n"
"(Multiple locations in the USA#Expansion regions of Philz Coffee)\n"
"\n"
"Relationships:\n"
"(Phil Jaber#Created#Philz Coffee"
"#Founded in Berkeley, California in 1978)\n"
"(Philz Coffee#Located in#Berkeley, California"
"#Founding location of Philz Coffee)\n"
"(Phil Jaber#Has#Jacob Jaber#Son of Phil Jaber)\n"
"(Jacob Jaber#Serves as#CEO#Became CEO of Philz Coffee in 2005)\n"
"(Philz Coffee#Expanded to#Multiple locations in the USA"
"#Expansion regions of Philz Coffee)\n"
"```\n"
"\n"
"----\n"
"\n"
"The following information from the [Context] and [KnowledgeGraph] can "
"help you better answer user questions.\n"
"\n"
"[Context]:\n"
"{context}\n"
"\n"
"[KnowledgeGraph]:\n"
"{graph}\n"
"\n"
)
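Putting it together, a hedged end-to-end sketch of the new graph type; a reachable graph store backend plus real llm_client and embedding_fn are assumed, and the names below are placeholders.

from dbgpt.core import Chunk
from dbgpt.storage.knowledge_graph.community_summary import (
    CommunitySummaryKnowledgeGraph,
    CommunitySummaryKnowledgeGraphConfig,
)


async def index_and_search(llm_client, embedding_fn, question: str):
    config = CommunitySummaryKnowledgeGraphConfig(
        name="demo_kg",
        llm_client=llm_client,
        model_name="your-model",
        embedding_fn=embedding_fn,
        vector_store_type="Chroma",
    )
    graph = CommunitySummaryKnowledgeGraph(config)
    await graph.aload_document([Chunk(content="Alice knows Bob. Bob works at Acme.")])
    return await graph.asimilar_search_with_scores(question, topk=5, score_threshold=0.0)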

View File

@@ -36,8 +36,9 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
def __init__(self, config: BuiltinKnowledgeGraphConfig):
"""Create builtin knowledge graph instance."""
self._config = config
super().__init__()
self._config = config
self._llm_client = config.llm_client
if not self._llm_client:
raise ValueError("No llm client provided.")
@@ -45,17 +46,19 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
self._model_name = config.model_name
self._triplet_extractor = TripletExtractor(self._llm_client, self._model_name)
self._keyword_extractor = KeywordExtractor(self._llm_client, self._model_name)
self._graph_store_type = (
os.getenv("GRAPH_STORE_TYPE", "TuGraph") or config.graph_store_type
)
self._graph_store = self.__init_graph_store(config)
def __init_graph_store(self, config) -> GraphStoreBase:
def configure(cfg: GraphStoreConfig):
cfg.name = self._config.name
cfg.embedding_fn = self._config.embedding_fn
cfg.name = config.name
cfg.embedding_fn = config.embedding_fn
self._graph_store: GraphStoreBase = GraphStoreFactory.create(
self._graph_store_type, configure
)
graph_store_type = os.getenv("GRAPH_STORE_TYPE") or config.graph_store_type
return GraphStoreFactory.create(graph_store_type, configure)
def get_config(self) -> BuiltinKnowledgeGraphConfig:
"""Get the knowledge graph config."""
return self._config
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Extract and persist triplets to graph store."""
@@ -113,35 +116,59 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
# extract keywords and explore graph store
keywords = await self._keyword_extractor.extract(text)
subgraph = self._graph_store.explore(keywords, limit=topk)
subgraph = self._graph_store.explore(keywords, limit=topk).format()
logger.info(f"Search subgraph from {len(keywords)} keywords")
if not subgraph:
return []
content = (
"The following vertices and edges data after [Subgraph Data] "
"are retrieved from the knowledge graph based on the keywords:\n"
f"Keywords:\n{','.join(keywords)}\n"
"The following entities and relationships provided after "
"[Subgraph] are retrieved from the knowledge graph "
"based on the keywords:\n"
f"\"{','.join(keywords)}\".\n"
"---------------------\n"
"You can refer to the sample vertices and edges to understand "
"the real knowledge graph data provided by [Subgraph Data].\n"
"Sample vertices:\n"
"The following examples after [Entities] and [Relationships] that "
"can help you understand the data format of the knowledge graph, "
"but do not use them in the answer.\n"
"[Entities]:\n"
"(alice)\n"
"(bob:{age:28})\n"
'(carry:{age:18;role:"teacher"})\n\n'
"Sample edges:\n"
"[Relationships]:\n"
"(alice)-[reward]->(alice)\n"
'(alice)-[notify:{method:"email"}]->'
'(carry:{age:18;role:"teacher"})\n'
'(bob:{age:28})-[teach:{course:"math";hour:180}]->(alice)\n'
"---------------------\n"
f"Subgraph Data:\n{subgraph.format()}\n"
f"[Subgraph]:\n{subgraph}\n"
)
return [Chunk(content=content, metadata=subgraph.schema())]
return [Chunk(content=content)]
def query_graph(self, limit: Optional[int] = None) -> Graph:
"""Query graph."""
return self._graph_store.get_full_graph(limit)
def truncate(self) -> List[str]:
"""Truncate knowledge graph."""
logger.info(f"Truncate graph {self._config.name}")
self._graph_store.truncate()
logger.info("Truncate keyword extractor")
self._keyword_extractor.truncate()
logger.info("Truncate triplet extractor")
self._triplet_extractor.truncate()
return [self._config.name]
def delete_vector_name(self, index_name: str):
"""Delete vector name."""
logger.info(f"Remove graph index {index_name}")
logger.info(f"Drop graph {index_name}")
self._graph_store.drop()
logger.info("Drop keyword extractor")
self._keyword_extractor.drop()
logger.info("Drop triplet extractor")
self._triplet_extractor.drop()

View File

@@ -1,12 +1,8 @@
"""OpenSPG class."""
import logging
from typing import List, Optional
from dbgpt._private.pydantic import ConfigDict
from dbgpt.core import Chunk
from dbgpt.storage.graph_store.graph import Graph, MemoryGraph
from dbgpt.storage.knowledge_graph.base import KnowledgeGraphBase, KnowledgeGraphConfig
from dbgpt.storage.vector_store.filters import MetadataFilters
logger = logging.getLogger(__name__)
@@ -21,29 +17,3 @@ class OpenSPG(KnowledgeGraphBase):
"""OpenSPG class."""
# todo: add OpenSPG implementation
def __init__(self, config: OpenSPGConfig):
"""Initialize the OpenSPG with config details."""
pass
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document."""
return []
def similar_search_with_scores(
self,
text,
topk,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Similar with scores."""
return []
def query_graph(self, limit: Optional[int] = None) -> Graph:
"""Query graph."""
return MemoryGraph()
def delete_vector_name(self, index_name: str):
"""Delete vector name."""
pass

View File

@@ -56,6 +56,15 @@ def _import_builtin_knowledge_graph() -> Tuple[Type, Type]:
return BuiltinKnowledgeGraph, BuiltinKnowledgeGraphConfig
def _import_community_summary_knowledge_graph() -> Tuple[Type, Type]:
from dbgpt.storage.knowledge_graph.community_summary import (
CommunitySummaryKnowledgeGraph,
CommunitySummaryKnowledgeGraphConfig,
)
return CommunitySummaryKnowledgeGraph, CommunitySummaryKnowledgeGraphConfig
def _import_openspg() -> Tuple[Type, Type]:
from dbgpt.storage.knowledge_graph.open_spg import OpenSPG, OpenSPGConfig
@@ -86,6 +95,8 @@ def __getattr__(name: str) -> Tuple[Type, Type]:
return _import_elastic()
elif name == "KnowledgeGraph":
return _import_builtin_knowledge_graph()
elif name == "CommunitySummaryKnowledgeGraph":
return _import_community_summary_knowledge_graph()
elif name == "OpenSPG":
return _import_openspg()
elif name == "FullText":
@@ -103,7 +114,7 @@ __vector_store__ = [
"ElasticSearch",
]
__knowledge_graph__ = ["KnowledgeGraph", "OpenSPG"]
__knowledge_graph__ = ["KnowledgeGraph", "CommunitySummaryKnowledgeGraph", "OpenSPG"]
__document_store__ = ["FullText"]
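For reference, a hedged sketch of how the lazy registry above is consumed. Name lookups go through the module-level __getattr__, which performs the deferred import and returns a (store class, config class) pair; the VectorStoreFactory further below relies on the same mechanism:

from dbgpt.storage import vector_store

# Resolved lazily via __getattr__; the heavy imports run only on first access.
store_cls, cfg_cls = getattr(vector_store, "CommunitySummaryKnowledgeGraph")
print(store_cls.__name__, cfg_cls.__name__)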

View File

@@ -99,6 +99,14 @@ class VectorStoreConfig(IndexStoreConfig):
"The password of vector store, if not set, will use the default password."
),
)
topk: int = Field(
default=5,
description="Topk of vector search",
)
score_threshold: float = Field(
default=0.3,
description="Recall score of vector search",
)
class VectorStoreBase(IndexStoreBase, ABC):
@@ -108,6 +116,10 @@ class VectorStoreBase(IndexStoreBase, ABC):
"""Initialize vector store."""
super().__init__(executor)
@abstractmethod
def get_config(self) -> VectorStoreConfig:
"""Get the vector store config."""
def filter_by_score_threshold(
self, chunks: List[Chunk], score_threshold: float
) -> List[Chunk]:
@@ -126,7 +138,7 @@ class VectorStoreBase(IndexStoreBase, ABC):
metadata=chunk.metadata,
content=chunk.content,
score=chunk.score,
chunk_id=str(id),
chunk_id=chunk.chunk_id,
)
for chunk in chunks
if chunk.score >= score_threshold
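A short sketch of how the new topk and score_threshold fields are intended to flow into retrieval. MyStore and MyStoreConfig stand in for any concrete VectorStoreBase subclass and its config, so treat this as illustrative only:

config = MyStoreConfig(name="demo", topk=5, score_threshold=0.3)  # assumed subclass
store = MyStore(config)                                           # assumed subclass

chunks = store.similar_search_with_scores(
    "query text",
    topk=config.topk,
    score_threshold=config.score_threshold,
)
# filter_by_score_threshold() drops chunks scoring below the threshold and,
# after this change, preserves each chunk's original chunk_id instead of str(id).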

View File

@@ -63,6 +63,8 @@ class ChromaStore(VectorStoreBase):
vector_store_config(ChromaVectorConfig): vector store config.
"""
super().__init__()
self._vector_store_config = vector_store_config
chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
chroma_path = chroma_vector_config.get(
"persist_path", os.path.join(PILOT_PATH, "data")
@@ -89,6 +91,10 @@ class ChromaStore(VectorStoreBase):
metadata=collection_metadata,
)
def get_config(self) -> ChromaVectorConfig:
"""Get the vector store config."""
return self._vector_store_config
def similar_search(
self, text, topk, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
@@ -100,10 +106,16 @@ class ChromaStore(VectorStoreBase):
filters=filters,
)
return [
Chunk(content=chroma_result[0], metadata=chroma_result[1] or {}, score=0.0)
Chunk(
content=chroma_result[0],
metadata=chroma_result[1] or {},
score=0.0,
chunk_id=chroma_result[2],
)
for chroma_result in zip(
chroma_results["documents"][0],
chroma_results["metadatas"][0],
chroma_results["ids"][0],
)
]
@@ -134,12 +146,14 @@ class ChromaStore(VectorStoreBase):
content=chroma_result[0],
metadata=chroma_result[1] or {},
score=(1 - chroma_result[2]),
chunk_id=chroma_result[3],
)
)
for chroma_result in zip(
chroma_results["documents"][0],
chroma_results["metadatas"][0],
chroma_results["distances"][0],
chroma_results["ids"][0],
)
]
return self.filter_by_score_threshold(chunks, score_threshold)
@@ -181,6 +195,20 @@ class ChromaStore(VectorStoreBase):
if len(ids) > 0:
self._collection.delete(ids=ids)
def truncate(self) -> List[str]:
"""Truncate data index_name."""
logger.info(f"begin truncate chroma collection:{self._collection.name}")
results = self._collection.get()
ids = results.get("ids")
if ids:
self._collection.delete(ids=ids)
logger.info(
f"Truncated {len(ids)} chunks from Chroma collection "
f"{self._collection.name}"
)
return ids
return []
def convert_metadata_filters(
self,
filters: MetadataFilters,
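A hedged usage sketch for the new ChromaStore.truncate(); the import path and my_embedding_fn are assumptions:

from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig

store = ChromaStore(ChromaVectorConfig(name="demo", embedding_fn=my_embedding_fn))
deleted_ids = store.truncate()  # removes every chunk but keeps the collection itself
print(f"Removed {len(deleted_ids)} chunks")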

View File

@@ -126,6 +126,8 @@ class ElasticStore(VectorStoreBase):
vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config.
"""
super().__init__()
self._vector_store_config = vector_store_config
connect_kwargs = {}
elasticsearch_vector_config = vector_store_config.dict()
self.uri = elasticsearch_vector_config.get("uri") or os.getenv(
@@ -234,6 +236,10 @@ class ElasticStore(VectorStoreBase):
except Exception as e:
logger.error(f"ElasticSearch connection failed: {e}")
def get_config(self) -> ElasticsearchVectorConfig:
"""Get the vector store config."""
return self._vector_store_config
def load_document(
self,
chunks: List[Chunk],

View File

@@ -0,0 +1,44 @@
"""Vector store factory."""
import logging
from typing import Tuple, Type
from dbgpt.storage import vector_store
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
logger = logging.getLogger(__name__)
class VectorStoreFactory:
"""Factory for vector store."""
@staticmethod
def create(
vector_store_type: str, vector_space_name: str, vector_store_configure=None
) -> VectorStoreBase:
"""Create a VectorStore instance.
Args:
- vector_store_type: vector store type (Chroma, Milvus, etc.)
- vector_space_name: name of the vector space (collection), passed to the configure callback
- vector_store_configure: optional callback that receives (vector_space_name, config) to customize the config before the store is built
"""
store_cls, cfg_cls = VectorStoreFactory.__find_type(vector_store_type)
try:
config = cfg_cls()
if vector_store_configure:
vector_store_configure(vector_space_name, config)
return store_cls(config)
except Exception as e:
logger.error("create vector store failed: %s", e)
raise e
@staticmethod
def __find_type(vector_store_type: str) -> Tuple[Type, Type]:
for t in vector_store.__vector_store__:
if t.lower() == vector_store_type.lower():
store_cls, cfg_cls = getattr(vector_store, t)
if issubclass(store_cls, VectorStoreBase) and issubclass(
cfg_cls, VectorStoreConfig
):
return store_cls, cfg_cls
raise Exception(f"Vector store {vector_store_type} not supported")
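A minimal usage sketch for the factory above. The configure callback mirrors the (vector_space_name, config) call made inside create(); my_embedding_fn is a placeholder:

def configure(vector_space_name: str, config: VectorStoreConfig) -> None:
    # Invoked by create() before the store class is instantiated.
    config.name = vector_space_name
    config.embedding_fn = my_embedding_fn  # assumed embedding function

store = VectorStoreFactory.create("Chroma", "demo_space", configure)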

View File

@@ -150,6 +150,8 @@ class MilvusStore(VectorStoreBase):
refer to https://milvus.io/docs/v2.0.x/manage_connection.md
"""
super().__init__()
self._vector_store_config = vector_store_config
try:
from pymilvus import connections
except ImportError:
@@ -363,6 +365,10 @@ class MilvusStore(VectorStoreBase):
return res.primary_keys
def get_config(self) -> MilvusVectorConfig:
"""Get the vector store config."""
return self._vector_store_config
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document in vector database."""
batch_size = 500

View File

@@ -718,6 +718,8 @@ class OceanBaseStore(VectorStoreBase):
if vector_store_config.embedding_fn is None:
raise ValueError("embedding_fn is required for OceanBaseStore")
super().__init__()
self._vector_store_config = vector_store_config
self.embeddings = vector_store_config.embedding_fn
self.collection_name = vector_store_config.name
vector_store_config = vector_store_config.dict()
@@ -760,6 +762,10 @@ class OceanBaseStore(VectorStoreBase):
enable_normalize_vector=self.OB_ENABLE_NORMALIZE_VECTOR,
)
def get_config(self) -> OceanBaseConfig:
"""Get the vector store config."""
return self._vector_store_config
def similar_search(
self, text, topk, filters: Optional[MetadataFilters] = None, **kwargs: Any
) -> List[Chunk]:

View File

@@ -64,6 +64,8 @@ class PGVectorStore(VectorStoreBase):
"Please install the `langchain` package to use the PGVector."
)
super().__init__()
self._vector_store_config = vector_store_config
self.connection_string = vector_store_config.connection_string
self.embeddings = vector_store_config.embedding_fn
self.collection_name = vector_store_config.name
@@ -74,6 +76,10 @@ class PGVectorStore(VectorStoreBase):
connection_string=self.connection_string,
)
def get_config(self) -> PGVectorConfig:
"""Get the vector store config."""
return self._vector_store_config
def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:

View File

@@ -69,6 +69,8 @@ class WeaviateStore(VectorStoreBase):
"Please install it with `pip install weaviate-client`."
)
super().__init__()
self._vector_store_config = vector_store_config
self.weaviate_url = vector_store_config.weaviate_url
self.embedding = vector_store_config.embedding_fn
self.vector_name = vector_store_config.name
@@ -78,6 +80,10 @@ class WeaviateStore(VectorStoreBase):
self.vector_store_client = weaviate.Client(self.weaviate_url)
def get_config(self) -> WeaviateVectorConfig:
"""Get the vector store config."""
return self._vector_store_config
def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
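Every store touched by this change now keeps its config and exposes it through the same get_config() accessor, so higher-level code (for example, the community-summary pipeline) can introspect any backend uniformly. A hedged sketch, assuming store is an instance of any of the stores above:

config = store.get_config()
print(config.name)  # vector store configs additionally carry topk and score_threshold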