feat(GraphRAG): enhance GraphRAG by graph community summary (#1801)

Co-authored-by: Florian <fanzhidongyzby@163.com>
Co-authored-by: KingSkyLi <15566300566@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: yvonneyx <zhuyuxin0627@gmail.com>
This commit is contained in:
M1n9X
2024-08-30 21:59:44 +08:00
committed by GitHub
parent 471689ba20
commit 759f7d99cc
59 changed files with 29316 additions and 411 deletions

View File

@@ -56,6 +56,15 @@ def _import_builtin_knowledge_graph() -> Tuple[Type, Type]:
return BuiltinKnowledgeGraph, BuiltinKnowledgeGraphConfig
def _import_community_summary_knowledge_graph() -> Tuple[Type, Type]:
    """Import the community-summary knowledge graph on demand.

    The import happens inside the function, so the
    ``dbgpt.storage.knowledge_graph.community_summary`` module is only loaded
    when this function is called.

    Returns:
        A ``(store class, config class)`` pair:
        ``CommunitySummaryKnowledgeGraph`` and
        ``CommunitySummaryKnowledgeGraphConfig``.
    """
    from dbgpt.storage.knowledge_graph import community_summary as _summary_module

    graph_cls = _summary_module.CommunitySummaryKnowledgeGraph
    config_cls = _summary_module.CommunitySummaryKnowledgeGraphConfig
    return graph_cls, config_cls
def _import_openspg() -> Tuple[Type, Type]:
from dbgpt.storage.knowledge_graph.open_spg import OpenSPG, OpenSPGConfig
@@ -86,6 +95,8 @@ def __getattr__(name: str) -> Tuple[Type, Type]:
return _import_elastic()
elif name == "KnowledgeGraph":
return _import_builtin_knowledge_graph()
elif name == "CommunitySummaryKnowledgeGraph":
return _import_community_summary_knowledge_graph()
elif name == "OpenSPG":
return _import_openspg()
elif name == "FullText":
@@ -103,7 +114,7 @@ __vector_store__ = [
"ElasticSearch",
]
__knowledge_graph__ = ["KnowledgeGraph", "OpenSPG"]
__knowledge_graph__ = ["KnowledgeGraph", "CommunitySummaryKnowledgeGraph", "OpenSPG"]
__document_store__ = ["FullText"]

View File

@@ -99,6 +99,14 @@ class VectorStoreConfig(IndexStoreConfig):
"The password of vector store, if not set, will use the default password."
),
)
# Top-k setting for vector search; defaults to 5.
topk: int = Field(
default=5,
description="Topk of vector search",
)
# Score setting for vector search recall; defaults to 0.3.
# NOTE(review): presumably the minimum score a chunk needs to be returned —
# confirm against the callers that read this field.
score_threshold: float = Field(
default=0.3,
description="Recall score of vector search",
)
class VectorStoreBase(IndexStoreBase, ABC):
@@ -108,6 +116,10 @@ class VectorStoreBase(IndexStoreBase, ABC):
"""Initialize vector store."""
super().__init__(executor)
@abstractmethod
def get_config(self) -> VectorStoreConfig:
    """Get the vector store config.

    Abstract method: every concrete vector store must implement it.

    Returns:
        The ``VectorStoreConfig`` (or subclass) describing this store.
    """
def filter_by_score_threshold(
self, chunks: List[Chunk], score_threshold: float
) -> List[Chunk]:
@@ -126,7 +138,7 @@ class VectorStoreBase(IndexStoreBase, ABC):
metadata=chunk.metadata,
content=chunk.content,
score=chunk.score,
chunk_id=str(id),
chunk_id=chunk.chunk_id,
)
for chunk in chunks
if chunk.score >= score_threshold

View File

@@ -63,6 +63,8 @@ class ChromaStore(VectorStoreBase):
vector_store_config(ChromaVectorConfig): vector store config.
"""
super().__init__()
self._vector_store_config = vector_store_config
chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
chroma_path = chroma_vector_config.get(
"persist_path", os.path.join(PILOT_PATH, "data")
@@ -89,6 +91,10 @@ class ChromaStore(VectorStoreBase):
metadata=collection_metadata,
)
def get_config(self) -> ChromaVectorConfig:
    """Return the Chroma config object stored on this instance.

    Returns:
        The value held in ``self._vector_store_config``.
    """
    stored_config = self._vector_store_config
    return stored_config
def similar_search(
self, text, topk, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
@@ -100,10 +106,16 @@ class ChromaStore(VectorStoreBase):
filters=filters,
)
return [
Chunk(content=chroma_result[0], metadata=chroma_result[1] or {}, score=0.0)
Chunk(
content=chroma_result[0],
metadata=chroma_result[1] or {},
score=0.0,
chunk_id=chroma_result[2],
)
for chroma_result in zip(
chroma_results["documents"][0],
chroma_results["metadatas"][0],
chroma_results["ids"][0],
)
]
@@ -134,12 +146,14 @@ class ChromaStore(VectorStoreBase):
content=chroma_result[0],
metadata=chroma_result[1] or {},
score=(1 - chroma_result[2]),
chunk_id=chroma_result[3],
)
)
for chroma_result in zip(
chroma_results["documents"][0],
chroma_results["metadatas"][0],
chroma_results["distances"][0],
chroma_results["ids"][0],
)
]
return self.filter_by_score_threshold(chunks, score_threshold)
@@ -181,6 +195,20 @@ class ChromaStore(VectorStoreBase):
if len(ids) > 0:
self._collection.delete(ids=ids)
def truncate(self) -> List[str]:
    """Remove every entry from the Chroma collection.

    Fetches the collection contents, deletes all returned ids, and logs
    the start and the successful end of the operation.

    Returns:
        The ids that were deleted, or an empty list when the collection
        returned no ids.
    """
    logger.info(f"begin truncate chroma collection:{self._collection.name}")
    existing_ids = self._collection.get().get("ids")
    # Nothing stored: skip the delete call and the success log.
    if not existing_ids:
        return []
    self._collection.delete(ids=existing_ids)
    logger.info(
        f"truncate chroma collection {self._collection.name} "
        f"{len(existing_ids)} chunks success"
    )
    return existing_ids
def convert_metadata_filters(
self,
filters: MetadataFilters,

View File

@@ -126,6 +126,8 @@ class ElasticStore(VectorStoreBase):
vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config.
"""
super().__init__()
self._vector_store_config = vector_store_config
connect_kwargs = {}
elasticsearch_vector_config = vector_store_config.dict()
self.uri = elasticsearch_vector_config.get("uri") or os.getenv(
@@ -234,6 +236,10 @@ class ElasticStore(VectorStoreBase):
except Exception as e:
logger.error(f"ElasticSearch connection failed: {e}")
def get_config(self) -> ElasticsearchVectorConfig:
    """Return the Elasticsearch config object stored on this instance.

    Returns:
        The value held in ``self._vector_store_config``.
    """
    stored_config = self._vector_store_config
    return stored_config
def load_document(
self,
chunks: List[Chunk],

View File

@@ -0,0 +1,44 @@
"""Vector store factory."""
import logging
from typing import Tuple, Type
from dbgpt.storage import vector_store
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
logger = logging.getLogger(__name__)
class VectorStoreFactory:
    """Factory that builds vector store instances from a store type name."""

    @staticmethod
    def create(
        vector_store_type: str, vector_space_name: str, vector_store_configure=None
    ) -> VectorStoreBase:
        """Create a VectorStore instance.

        Args:
            vector_store_type: Name of the vector store type (Chroma, Milvus,
                etc.), matched case-insensitively against the names listed in
                ``vector_store.__vector_store__``.
            vector_space_name: Passed as the first argument to
                ``vector_store_configure``.
            vector_store_configure: Optional callable invoked as
                ``vector_store_configure(vector_space_name, config)`` on the
                default-constructed config before the store is built.

        Returns:
            The store instance built from the config.

        Raises:
            Exception: If ``vector_store_type`` is not supported. Any
                exception raised while building the config or the store is
                logged and re-raised.
        """
        store_cls, cfg_cls = VectorStoreFactory.__find_type(vector_store_type)

        try:
            config = cfg_cls()
            if vector_store_configure:
                vector_store_configure(vector_space_name, config)
            return store_cls(config)
        except Exception as e:
            logger.error("create vector store failed: %s", e)
            # Bare ``raise`` re-raises the active exception unchanged.
            raise

    @staticmethod
    def __find_type(vector_store_type: str) -> Tuple[Type, Type]:
        """Resolve a store type name to its ``(store class, config class)``.

        A name matches when it equals ``vector_store_type`` ignoring case and
        the resolved classes subclass ``VectorStoreBase`` and
        ``VectorStoreConfig`` respectively.

        Raises:
            Exception: If no registered name yields a valid pair.
        """
        wanted = vector_store_type.lower()
        for t in vector_store.__vector_store__:
            if t.lower() != wanted:
                continue
            store_cls, cfg_cls = getattr(vector_store, t)
            if issubclass(store_cls, VectorStoreBase) and issubclass(
                cfg_cls, VectorStoreConfig
            ):
                return store_cls, cfg_cls
        raise Exception(f"Vector store {vector_store_type} not supported")

View File

@@ -150,6 +150,8 @@ class MilvusStore(VectorStoreBase):
refer to https://milvus.io/docs/v2.0.x/manage_connection.md
"""
super().__init__()
self._vector_store_config = vector_store_config
try:
from pymilvus import connections
except ImportError:
@@ -363,6 +365,10 @@ class MilvusStore(VectorStoreBase):
return res.primary_keys
def get_config(self) -> MilvusVectorConfig:
    """Return the Milvus config object stored on this instance.

    Returns:
        The value held in ``self._vector_store_config``.
    """
    stored_config = self._vector_store_config
    return stored_config
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document in vector database."""
batch_size = 500

View File

@@ -718,6 +718,8 @@ class OceanBaseStore(VectorStoreBase):
if vector_store_config.embedding_fn is None:
raise ValueError("embedding_fn is required for OceanBaseStore")
super().__init__()
self._vector_store_config = vector_store_config
self.embeddings = vector_store_config.embedding_fn
self.collection_name = vector_store_config.name
vector_store_config = vector_store_config.dict()
@@ -760,6 +762,10 @@ class OceanBaseStore(VectorStoreBase):
enable_normalize_vector=self.OB_ENABLE_NORMALIZE_VECTOR,
)
def get_config(self) -> OceanBaseConfig:
    """Return the OceanBase config object stored on this instance.

    Returns:
        The value held in ``self._vector_store_config``.
    """
    stored_config = self._vector_store_config
    return stored_config
def similar_search(
self, text, topk, filters: Optional[MetadataFilters] = None, **kwargs: Any
) -> List[Chunk]:

View File

@@ -64,6 +64,8 @@ class PGVectorStore(VectorStoreBase):
"Please install the `langchain` package to use the PGVector."
)
super().__init__()
self._vector_store_config = vector_store_config
self.connection_string = vector_store_config.connection_string
self.embeddings = vector_store_config.embedding_fn
self.collection_name = vector_store_config.name
@@ -74,6 +76,10 @@ class PGVectorStore(VectorStoreBase):
connection_string=self.connection_string,
)
def get_config(self) -> PGVectorConfig:
    """Return the PGVector config object stored on this instance.

    Returns:
        The value held in ``self._vector_store_config``.
    """
    stored_config = self._vector_store_config
    return stored_config
def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:

View File

@@ -69,6 +69,8 @@ class WeaviateStore(VectorStoreBase):
"Please install it with `pip install weaviate-client`."
)
super().__init__()
self._vector_store_config = vector_store_config
self.weaviate_url = vector_store_config.weaviate_url
self.embedding = vector_store_config.embedding_fn
self.vector_name = vector_store_config.name
@@ -78,6 +80,10 @@ class WeaviateStore(VectorStoreBase):
self.vector_store_client = weaviate.Client(self.weaviate_url)
def get_config(self) -> WeaviateVectorConfig:
    """Return the Weaviate config object stored on this instance.

    Returns:
        The value held in ``self._vector_store_config``.
    """
    stored_config = self._vector_store_config
    return stored_config
def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]: