mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-28 13:00:02 +00:00
✨ feat(GraphRAG): enhance GraphRAG by graph community summary (#1801)
Co-authored-by: Florian <fanzhidongyzby@163.com> Co-authored-by: KingSkyLi <15566300566@163.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com> Co-authored-by: yvonneyx <zhuyuxin0627@gmail.com>
This commit is contained in:
@@ -56,6 +56,15 @@ def _import_builtin_knowledge_graph() -> Tuple[Type, Type]:
|
||||
return BuiltinKnowledgeGraph, BuiltinKnowledgeGraphConfig
|
||||
|
||||
|
||||
def _import_community_summary_knowledge_graph() -> Tuple[Type, Type]:
|
||||
from dbgpt.storage.knowledge_graph.community_summary import (
|
||||
CommunitySummaryKnowledgeGraph,
|
||||
CommunitySummaryKnowledgeGraphConfig,
|
||||
)
|
||||
|
||||
return CommunitySummaryKnowledgeGraph, CommunitySummaryKnowledgeGraphConfig
|
||||
|
||||
|
||||
def _import_openspg() -> Tuple[Type, Type]:
|
||||
from dbgpt.storage.knowledge_graph.open_spg import OpenSPG, OpenSPGConfig
|
||||
|
||||
@@ -86,6 +95,8 @@ def __getattr__(name: str) -> Tuple[Type, Type]:
|
||||
return _import_elastic()
|
||||
elif name == "KnowledgeGraph":
|
||||
return _import_builtin_knowledge_graph()
|
||||
elif name == "CommunitySummaryKnowledgeGraph":
|
||||
return _import_community_summary_knowledge_graph()
|
||||
elif name == "OpenSPG":
|
||||
return _import_openspg()
|
||||
elif name == "FullText":
|
||||
@@ -103,7 +114,7 @@ __vector_store__ = [
|
||||
"ElasticSearch",
|
||||
]
|
||||
|
||||
__knowledge_graph__ = ["KnowledgeGraph", "OpenSPG"]
|
||||
__knowledge_graph__ = ["KnowledgeGraph", "CommunitySummaryKnowledgeGraph", "OpenSPG"]
|
||||
|
||||
__document_store__ = ["FullText"]
|
||||
|
||||
|
@@ -99,6 +99,14 @@ class VectorStoreConfig(IndexStoreConfig):
|
||||
"The password of vector store, if not set, will use the default password."
|
||||
),
|
||||
)
|
||||
topk: int = Field(
|
||||
default=5,
|
||||
description="Topk of vector search",
|
||||
)
|
||||
score_threshold: float = Field(
|
||||
default=0.3,
|
||||
description="Recall score of vector search",
|
||||
)
|
||||
|
||||
|
||||
class VectorStoreBase(IndexStoreBase, ABC):
|
||||
@@ -108,6 +116,10 @@ class VectorStoreBase(IndexStoreBase, ABC):
|
||||
"""Initialize vector store."""
|
||||
super().__init__(executor)
|
||||
|
||||
@abstractmethod
|
||||
def get_config(self) -> VectorStoreConfig:
|
||||
"""Get the vector store config."""
|
||||
|
||||
def filter_by_score_threshold(
|
||||
self, chunks: List[Chunk], score_threshold: float
|
||||
) -> List[Chunk]:
|
||||
@@ -126,7 +138,7 @@ class VectorStoreBase(IndexStoreBase, ABC):
|
||||
metadata=chunk.metadata,
|
||||
content=chunk.content,
|
||||
score=chunk.score,
|
||||
chunk_id=str(id),
|
||||
chunk_id=chunk.chunk_id,
|
||||
)
|
||||
for chunk in chunks
|
||||
if chunk.score >= score_threshold
|
||||
|
@@ -63,6 +63,8 @@ class ChromaStore(VectorStoreBase):
|
||||
vector_store_config(ChromaVectorConfig): vector store config.
|
||||
"""
|
||||
super().__init__()
|
||||
self._vector_store_config = vector_store_config
|
||||
|
||||
chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
|
||||
chroma_path = chroma_vector_config.get(
|
||||
"persist_path", os.path.join(PILOT_PATH, "data")
|
||||
@@ -89,6 +91,10 @@ class ChromaStore(VectorStoreBase):
|
||||
metadata=collection_metadata,
|
||||
)
|
||||
|
||||
def get_config(self) -> ChromaVectorConfig:
|
||||
"""Get the vector store config."""
|
||||
return self._vector_store_config
|
||||
|
||||
def similar_search(
|
||||
self, text, topk, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
@@ -100,10 +106,16 @@ class ChromaStore(VectorStoreBase):
|
||||
filters=filters,
|
||||
)
|
||||
return [
|
||||
Chunk(content=chroma_result[0], metadata=chroma_result[1] or {}, score=0.0)
|
||||
Chunk(
|
||||
content=chroma_result[0],
|
||||
metadata=chroma_result[1] or {},
|
||||
score=0.0,
|
||||
chunk_id=chroma_result[2],
|
||||
)
|
||||
for chroma_result in zip(
|
||||
chroma_results["documents"][0],
|
||||
chroma_results["metadatas"][0],
|
||||
chroma_results["ids"][0],
|
||||
)
|
||||
]
|
||||
|
||||
@@ -134,12 +146,14 @@ class ChromaStore(VectorStoreBase):
|
||||
content=chroma_result[0],
|
||||
metadata=chroma_result[1] or {},
|
||||
score=(1 - chroma_result[2]),
|
||||
chunk_id=chroma_result[3],
|
||||
)
|
||||
)
|
||||
for chroma_result in zip(
|
||||
chroma_results["documents"][0],
|
||||
chroma_results["metadatas"][0],
|
||||
chroma_results["distances"][0],
|
||||
chroma_results["ids"][0],
|
||||
)
|
||||
]
|
||||
return self.filter_by_score_threshold(chunks, score_threshold)
|
||||
@@ -181,6 +195,20 @@ class ChromaStore(VectorStoreBase):
|
||||
if len(ids) > 0:
|
||||
self._collection.delete(ids=ids)
|
||||
|
||||
def truncate(self) -> List[str]:
|
||||
"""Truncate data index_name."""
|
||||
logger.info(f"begin truncate chroma collection:{self._collection.name}")
|
||||
results = self._collection.get()
|
||||
ids = results.get("ids")
|
||||
if ids:
|
||||
self._collection.delete(ids=ids)
|
||||
logger.info(
|
||||
f"truncate chroma collection {self._collection.name} "
|
||||
f"{len(ids)} chunks success"
|
||||
)
|
||||
return ids
|
||||
return []
|
||||
|
||||
def convert_metadata_filters(
|
||||
self,
|
||||
filters: MetadataFilters,
|
||||
|
@@ -126,6 +126,8 @@ class ElasticStore(VectorStoreBase):
|
||||
vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config.
|
||||
"""
|
||||
super().__init__()
|
||||
self._vector_store_config = vector_store_config
|
||||
|
||||
connect_kwargs = {}
|
||||
elasticsearch_vector_config = vector_store_config.dict()
|
||||
self.uri = elasticsearch_vector_config.get("uri") or os.getenv(
|
||||
@@ -234,6 +236,10 @@ class ElasticStore(VectorStoreBase):
|
||||
except Exception as e:
|
||||
logger.error(f"ElasticSearch connection failed: {e}")
|
||||
|
||||
def get_config(self) -> ElasticsearchVectorConfig:
|
||||
"""Get the vector store config."""
|
||||
return self._vector_store_config
|
||||
|
||||
def load_document(
|
||||
self,
|
||||
chunks: List[Chunk],
|
||||
|
44
dbgpt/storage/vector_store/factory.py
Normal file
44
dbgpt/storage/vector_store/factory.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Vector store factory."""
|
||||
import logging
|
||||
from typing import Tuple, Type
|
||||
|
||||
from dbgpt.storage import vector_store
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VectorStoreFactory:
|
||||
"""Factory for vector store."""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
vector_store_type: str, vector_space_name: str, vector_store_configure=None
|
||||
) -> VectorStoreBase:
|
||||
"""Create a VectorStore instance.
|
||||
|
||||
Args:
|
||||
- vector_store_type: vector store type Chroma, Milvus, etc.
|
||||
- vector_store_config: vector store config
|
||||
"""
|
||||
store_cls, cfg_cls = VectorStoreFactory.__find_type(vector_store_type)
|
||||
|
||||
try:
|
||||
config = cfg_cls()
|
||||
if vector_store_configure:
|
||||
vector_store_configure(vector_space_name, config)
|
||||
return store_cls(config)
|
||||
except Exception as e:
|
||||
logger.error("create vector store failed: %s", e)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def __find_type(vector_store_type: str) -> Tuple[Type, Type]:
|
||||
for t in vector_store.__vector_store__:
|
||||
if t.lower() == vector_store_type.lower():
|
||||
store_cls, cfg_cls = getattr(vector_store, t)
|
||||
if issubclass(store_cls, VectorStoreBase) and issubclass(
|
||||
cfg_cls, VectorStoreConfig
|
||||
):
|
||||
return store_cls, cfg_cls
|
||||
raise Exception(f"Vector store {vector_store_type} not supported")
|
@@ -150,6 +150,8 @@ class MilvusStore(VectorStoreBase):
|
||||
refer to https://milvus.io/docs/v2.0.x/manage_connection.md
|
||||
"""
|
||||
super().__init__()
|
||||
self._vector_store_config = vector_store_config
|
||||
|
||||
try:
|
||||
from pymilvus import connections
|
||||
except ImportError:
|
||||
@@ -363,6 +365,10 @@ class MilvusStore(VectorStoreBase):
|
||||
|
||||
return res.primary_keys
|
||||
|
||||
def get_config(self) -> MilvusVectorConfig:
|
||||
"""Get the vector store config."""
|
||||
return self._vector_store_config
|
||||
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""Load document in vector database."""
|
||||
batch_size = 500
|
||||
|
@@ -718,6 +718,8 @@ class OceanBaseStore(VectorStoreBase):
|
||||
if vector_store_config.embedding_fn is None:
|
||||
raise ValueError("embedding_fn is required for OceanBaseStore")
|
||||
super().__init__()
|
||||
self._vector_store_config = vector_store_config
|
||||
|
||||
self.embeddings = vector_store_config.embedding_fn
|
||||
self.collection_name = vector_store_config.name
|
||||
vector_store_config = vector_store_config.dict()
|
||||
@@ -760,6 +762,10 @@ class OceanBaseStore(VectorStoreBase):
|
||||
enable_normalize_vector=self.OB_ENABLE_NORMALIZE_VECTOR,
|
||||
)
|
||||
|
||||
def get_config(self) -> OceanBaseConfig:
|
||||
"""Get the vector store config."""
|
||||
return self._vector_store_config
|
||||
|
||||
def similar_search(
|
||||
self, text, topk, filters: Optional[MetadataFilters] = None, **kwargs: Any
|
||||
) -> List[Chunk]:
|
||||
|
@@ -64,6 +64,8 @@ class PGVectorStore(VectorStoreBase):
|
||||
"Please install the `langchain` package to use the PGVector."
|
||||
)
|
||||
super().__init__()
|
||||
self._vector_store_config = vector_store_config
|
||||
|
||||
self.connection_string = vector_store_config.connection_string
|
||||
self.embeddings = vector_store_config.embedding_fn
|
||||
self.collection_name = vector_store_config.name
|
||||
@@ -74,6 +76,10 @@ class PGVectorStore(VectorStoreBase):
|
||||
connection_string=self.connection_string,
|
||||
)
|
||||
|
||||
def get_config(self) -> PGVectorConfig:
|
||||
"""Get the vector store config."""
|
||||
return self._vector_store_config
|
||||
|
||||
def similar_search(
|
||||
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
|
@@ -69,6 +69,8 @@ class WeaviateStore(VectorStoreBase):
|
||||
"Please install it with `pip install weaviate-client`."
|
||||
)
|
||||
super().__init__()
|
||||
self._vector_store_config = vector_store_config
|
||||
|
||||
self.weaviate_url = vector_store_config.weaviate_url
|
||||
self.embedding = vector_store_config.embedding_fn
|
||||
self.vector_name = vector_store_config.name
|
||||
@@ -78,6 +80,10 @@ class WeaviateStore(VectorStoreBase):
|
||||
|
||||
self.vector_store_client = weaviate.Client(self.weaviate_url)
|
||||
|
||||
def get_config(self) -> WeaviateVectorConfig:
|
||||
"""Get the vector store config."""
|
||||
return self._vector_store_config
|
||||
|
||||
def similar_search(
|
||||
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
|
Reference in New Issue
Block a user