mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-06 02:46:40 +00:00
refactor(GraphRAG): refine config usage and fix some bug. (#2392)
Co-authored-by: 秉翟 <lyusonglin.lsl@antgroup.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com> Co-authored-by: yyhhyyyyyy <95077259+yyhhyyyyyy@users.noreply.github.com> Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
parent
d5a2a0bf3b
commit
3bd75d8de2
@ -35,8 +35,26 @@ host="127.0.0.1"
|
|||||||
port=7687
|
port=7687
|
||||||
username="admin"
|
username="admin"
|
||||||
password="73@TuGraph"
|
password="73@TuGraph"
|
||||||
#enable_summary="True"
|
|
||||||
#enable_similarity_search="True"
|
# enable_summary="True"
|
||||||
|
# community_topk=20
|
||||||
|
# community_score_threshold=0.3
|
||||||
|
|
||||||
|
# triplet_graph_enabled="True"
|
||||||
|
# extract_topk=20
|
||||||
|
|
||||||
|
# document_graph_enabled="True"
|
||||||
|
# knowledge_graph_chunk_search_top_size=20
|
||||||
|
# knowledge_graph_extraction_batch_size=20
|
||||||
|
|
||||||
|
# enable_similarity_search="True"
|
||||||
|
# knowledge_graph_embedding_batch_size=20
|
||||||
|
# similarity_search_topk=5
|
||||||
|
# extract_score_threshold=0.7
|
||||||
|
|
||||||
|
# enable_text_search="True"
|
||||||
|
# text2gql_model_enabled="True"
|
||||||
|
# text2gql_model_name="qwen2.5:latest"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,24 +37,18 @@ class GraphRetriever(GraphRetrieverBase):
|
|||||||
graph_store_adapter,
|
graph_store_adapter,
|
||||||
):
|
):
|
||||||
"""Initialize Graph Retriever."""
|
"""Initialize Graph Retriever."""
|
||||||
self._triplet_graph_enabled = (
|
self._triplet_graph_enabled = config.triplet_graph_enabled or (
|
||||||
os.environ["TRIPLET_GRAPH_ENABLED"].lower() == "true"
|
os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
|
||||||
if "TRIPLET_GRAPH_ENABLED" in os.environ
|
|
||||||
else config.triplet_graph_enabled
|
|
||||||
)
|
)
|
||||||
self._document_graph_enabled = (
|
self._document_graph_enabled = config.document_graph_enabled or (
|
||||||
os.environ["DOCUMENT_GRAPH_ENABLED"].lower() == "true"
|
os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
|
||||||
if "DOCUMENT_GRAPH_ENABLED" in os.environ
|
|
||||||
else config.document_graph_enabled
|
|
||||||
)
|
)
|
||||||
triplet_topk = int(
|
triplet_topk = int(
|
||||||
os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE", config.extract_topk)
|
config.extract_topk or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE")
|
||||||
)
|
)
|
||||||
document_topk = int(
|
document_topk = int(
|
||||||
os.getenv(
|
config.knowledge_graph_chunk_search_top_size
|
||||||
"KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE",
|
or os.getenv("KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE")
|
||||||
config.knowledge_graph_chunk_search_top_size,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
llm_client = config.llm_client
|
llm_client = config.llm_client
|
||||||
model_name = config.model_name
|
model_name = config.model_name
|
||||||
@ -62,27 +56,43 @@ class GraphRetriever(GraphRetrieverBase):
|
|||||||
graph_store_adapter.graph_store.enable_similarity_search
|
graph_store_adapter.graph_store.enable_similarity_search
|
||||||
)
|
)
|
||||||
self._embedding_batch_size = int(
|
self._embedding_batch_size = int(
|
||||||
os.getenv(
|
config.knowledge_graph_embedding_batch_size
|
||||||
"KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE",
|
or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE")
|
||||||
config.knowledge_graph_embedding_batch_size,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
similarity_search_topk = int(
|
similarity_search_topk = int(
|
||||||
os.getenv(
|
config.similarity_search_topk
|
||||||
"KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE",
|
or os.getenv("KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE")
|
||||||
config.similarity_search_topk,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
similarity_search_score_threshold = float(
|
similarity_search_score_threshold = float(
|
||||||
os.getenv(
|
config.extract_score_threshold
|
||||||
"KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE",
|
or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE")
|
||||||
config.extract_score_threshold,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self._enable_text_search = (
|
self._enable_text_search = config.enable_text_search or (
|
||||||
os.environ["TEXT_SEARCH_ENABLED"].lower() == "true"
|
os.getenv("TEXT_SEARCH_ENABLED", "").lower() == "true"
|
||||||
if "TEXT_SEARCH_ENABLED" in os.environ
|
)
|
||||||
else config.enable_text_search
|
text2gql_model_enabled = config.text2gql_model_enabled or (
|
||||||
|
os.getenv("TEXT2GQL_MODEL_ENABLED", "").lower() == "true"
|
||||||
|
)
|
||||||
|
text2gql_model_name = config.text2gql_model_name or os.getenv(
|
||||||
|
"TEXT2GQL_MODEL_NAME"
|
||||||
|
)
|
||||||
|
text2gql_model_enabled = (
|
||||||
|
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
|
||||||
|
if "TEXT2GQL_MODEL_ENABLED" in os.environ
|
||||||
|
else config.text2gql_model_enabled
|
||||||
|
)
|
||||||
|
text2gql_model_name = os.getenv(
|
||||||
|
"TEXT2GQL_MODEL_NAME",
|
||||||
|
config.text2gql_model_name,
|
||||||
|
)
|
||||||
|
text2gql_model_enabled = (
|
||||||
|
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
|
||||||
|
if "TEXT2GQL_MODEL_ENABLED" in os.environ
|
||||||
|
else config.text2gql_model_enabled
|
||||||
|
)
|
||||||
|
text2gql_model_name = os.getenv(
|
||||||
|
"TEXT2GQL_MODEL_NAME",
|
||||||
|
config.text2gql_model_name,
|
||||||
)
|
)
|
||||||
text2gql_model_enabled = (
|
text2gql_model_enabled = (
|
||||||
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
|
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
|
||||||
|
@ -95,7 +95,7 @@ logger = logging.getLogger(__name__)
|
|||||||
float,
|
float,
|
||||||
description=_("Recall score of community search in knowledge graph"),
|
description=_("Recall score of community search in knowledge graph"),
|
||||||
optional=True,
|
optional=True,
|
||||||
default=0.0,
|
default=0.3,
|
||||||
),
|
),
|
||||||
Parameter.build_from(
|
Parameter.build_from(
|
||||||
_("Enable the graph search for documents and chunks"),
|
_("Enable the graph search for documents and chunks"),
|
||||||
@ -171,7 +171,7 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
|
|||||||
description="Topk of community search in knowledge graph",
|
description="Topk of community search in knowledge graph",
|
||||||
)
|
)
|
||||||
community_score_threshold: float = Field(
|
community_score_threshold: float = Field(
|
||||||
default=0.0,
|
default=0.3,
|
||||||
description="Recall score of community search in knowledge graph",
|
description="Recall score of community search in knowledge graph",
|
||||||
)
|
)
|
||||||
triplet_graph_enabled: bool = Field(
|
triplet_graph_enabled: bool = Field(
|
||||||
@ -244,56 +244,41 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
self._vector_store_type = os.getenv(
|
self._vector_store_type = config.vector_store_type or os.getenv(
|
||||||
"VECTOR_STORE_TYPE", config.vector_store_type
|
"VECTOR_STORE_TYPE"
|
||||||
)
|
)
|
||||||
self._extract_topk = int(
|
self._extract_topk = int(
|
||||||
os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE", config.extract_topk)
|
config.extract_topk or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE")
|
||||||
)
|
)
|
||||||
self._extract_score_threshold = float(
|
self._extract_score_threshold = float(
|
||||||
os.getenv(
|
config.extract_score_threshold
|
||||||
"KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE",
|
or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE")
|
||||||
config.extract_score_threshold,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self._community_topk = int(
|
self._community_topk = int(
|
||||||
os.getenv(
|
config.community_topk
|
||||||
"KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE", config.community_topk
|
or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE")
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self._community_score_threshold = float(
|
self._community_score_threshold = float(
|
||||||
os.getenv(
|
config.community_score_threshold
|
||||||
"KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE",
|
or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE")
|
||||||
config.community_score_threshold,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self._document_graph_enabled = (
|
self._document_graph_enabled = config.document_graph_enabled or (
|
||||||
os.environ["DOCUMENT_GRAPH_ENABLED"].lower() == "true"
|
os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
|
||||||
if "DOCUMENT_GRAPH_ENABLED" in os.environ
|
|
||||||
else config.document_graph_enabled
|
|
||||||
)
|
)
|
||||||
self._triplet_graph_enabled = (
|
self._triplet_graph_enabled = config.triplet_graph_enabled or (
|
||||||
os.environ["TRIPLET_GRAPH_ENABLED"].lower() == "true"
|
os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
|
||||||
if "TRIPLET_GRAPH_ENABLED" in os.environ
|
|
||||||
else config.triplet_graph_enabled
|
|
||||||
)
|
)
|
||||||
self._triplet_extraction_batch_size = int(
|
self._triplet_extraction_batch_size = int(
|
||||||
os.getenv(
|
config.knowledge_graph_extraction_batch_size
|
||||||
"KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE",
|
or os.getenv("KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE")
|
||||||
config.knowledge_graph_extraction_batch_size,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self._triplet_embedding_batch_size = int(
|
self._triplet_embedding_batch_size = int(
|
||||||
os.getenv(
|
config.knowledge_graph_embedding_batch_size
|
||||||
"KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE",
|
or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE")
|
||||||
config.knowledge_graph_embedding_batch_size,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self._community_summary_batch_size = int(
|
self._community_summary_batch_size = int(
|
||||||
os.getenv(
|
config.community_summary_batch_size
|
||||||
"COMMUNITY_SUMMARY_BATCH_SIZE",
|
or os.getenv("COMMUNITY_SUMMARY_BATCH_SIZE")
|
||||||
config.community_summary_batch_size,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def extractor_configure(name: str, cfg: VectorStoreConfig):
|
def extractor_configure(name: str, cfg: VectorStoreConfig):
|
||||||
|
@ -149,7 +149,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
|||||||
cfg.name = config.name
|
cfg.name = config.name
|
||||||
cfg.embedding_fn = config.embedding_fn
|
cfg.embedding_fn = config.embedding_fn
|
||||||
|
|
||||||
graph_store_type = os.getenv("GRAPH_STORE_TYPE") or config.type
|
graph_store_type = config.type or os.getenv("GRAPH_STORE_TYPE")
|
||||||
return GraphStoreFactory.create(graph_store_type, configure, config.dict())
|
return GraphStoreFactory.create(graph_store_type, configure, config.dict())
|
||||||
|
|
||||||
def __init_graph_store_adapter(self):
|
def __init_graph_store_adapter(self):
|
||||||
|
@ -96,7 +96,7 @@ class VectorStoreConnector:
|
|||||||
|
|
||||||
def __rewrite_index_store_type(self, index_store_type):
|
def __rewrite_index_store_type(self, index_store_type):
|
||||||
# Rewrite Knowledge Graph Type
|
# Rewrite Knowledge Graph Type
|
||||||
if self.app_config.rag.graph_community_summary_enabled:
|
if self.app_config.rag.storage.graph.get("enable_summary").lower() == "true":
|
||||||
if index_store_type == "KnowledgeGraph":
|
if index_store_type == "KnowledgeGraph":
|
||||||
return "CommunitySummaryKnowledgeGraph"
|
return "CommunitySummaryKnowledgeGraph"
|
||||||
return index_store_type
|
return index_store_type
|
||||||
|
Loading…
Reference in New Issue
Block a user