refactor(GraphRAG): refine config usage and fix some bugs. (#2392)

Co-authored-by: 秉翟 <lyusonglin.lsl@antgroup.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: yyhhyyyyyy <95077259+yyhhyyyyyy@users.noreply.github.com>
Co-authored-by: aries_ckt <916701291@qq.com>
SonglinLyu 2025-03-06 03:09:10 +08:00 committed by GitHub
parent d5a2a0bf3b
commit 3bd75d8de2
5 changed files with 82 additions and 69 deletions

View File

@@ -35,8 +35,26 @@ host="127.0.0.1"
port=7687
username="admin"
password="73@TuGraph"
-#enable_summary="True"
-#enable_similarity_search="True"
+# enable_summary="True"
+# community_topk=20
+# community_score_threshold=0.3
+# triplet_graph_enabled="True"
+# extract_topk=20
+# document_graph_enabled="True"
+# knowledge_graph_chunk_search_top_size=20
+# knowledge_graph_extraction_batch_size=20
+# enable_similarity_search="True"
+# knowledge_graph_embedding_batch_size=20
+# similarity_search_topk=5
+# extract_score_threshold=0.7
+# enable_text_search="True"
+# text2gql_model_enabled="True"
+# text2gql_model_name="qwen2.5:latest"
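The newly documented options mirror the fields of CommunitySummaryKnowledgeGraphConfig shown further down. Note that the boolean flags are written as quoted strings ("True"), so consumers compare them case-insensitively against "true" instead of reading a TOML boolean. A minimal sketch of that parsing, assuming the graph section is loaded into a plain dict; the parse_flag helper is hypothetical, not DB-GPT API:

```python
# Minimal sketch: interpreting the string-valued flags from the example above.
# `parse_flag` is a hypothetical helper, not part of the DB-GPT codebase.
def parse_flag(value, default: bool = False) -> bool:
    """Treat quoted TOML strings such as "True" as booleans."""
    if value is None:
        return default
    return str(value).lower() == "true"

graph_section = {
    "enable_summary": "True",
    "community_score_threshold": 0.3,
    "similarity_search_topk": 5,
}

enable_summary = parse_flag(graph_section.get("enable_summary"))          # True
enable_text_search = parse_flag(graph_section.get("enable_text_search"))  # False
similarity_search_topk = int(graph_section.get("similarity_search_topk", 5))
```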

View File

@@ -37,24 +37,18 @@ class GraphRetriever(GraphRetrieverBase):
graph_store_adapter,
):
"""Initialize Graph Retriever."""
-self._triplet_graph_enabled = (
-os.environ["TRIPLET_GRAPH_ENABLED"].lower() == "true"
-if "TRIPLET_GRAPH_ENABLED" in os.environ
-else config.triplet_graph_enabled
+self._triplet_graph_enabled = config.triplet_graph_enabled or (
+os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
)
-self._document_graph_enabled = (
-os.environ["DOCUMENT_GRAPH_ENABLED"].lower() == "true"
-if "DOCUMENT_GRAPH_ENABLED" in os.environ
-else config.document_graph_enabled
+self._document_graph_enabled = config.document_graph_enabled or (
+os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
)
triplet_topk = int(
os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE", config.extract_topk)
config.extract_topk or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE")
)
document_topk = int(
-os.getenv(
-"KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE",
-config.knowledge_graph_chunk_search_top_size,
-)
+config.knowledge_graph_chunk_search_top_size
+or os.getenv("KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE")
)
llm_client = config.llm_client
model_name = config.model_name
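The hunk above flips the precedence of every setting: the config field is consulted first and the environment variable only acts as a fallback. A minimal sketch of the pattern under hypothetical helper names (resolve_bool and resolve_int are not DB-GPT APIs); note that `or` treats a falsy config value (None, False, 0) as unset, so it falls through to the environment:

```python
import os

# Illustrative helpers mirroring the config-first resolution above;
# the function names are hypothetical, not part of the DB-GPT codebase.
def resolve_bool(config_value, env_name: str) -> bool:
    # A config value of None or False falls through to the environment,
    # because `or` treats any falsy value as "not set".
    return bool(config_value or os.getenv(env_name, "").lower() == "true")

def resolve_int(config_value, env_name: str, default: int) -> int:
    # int() would fail on None, so an explicit default covers the case
    # where neither the config field nor the environment variable is set.
    return int(config_value or os.getenv(env_name) or default)

triplet_graph_enabled = resolve_bool(None, "TRIPLET_GRAPH_ENABLED")
extract_topk = resolve_int(None, "KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE", 5)
```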
@@ -62,27 +56,43 @@ class GraphRetriever(GraphRetrieverBase):
graph_store_adapter.graph_store.enable_similarity_search
)
self._embedding_batch_size = int(
-os.getenv(
-"KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE",
-config.knowledge_graph_embedding_batch_size,
-)
+config.knowledge_graph_embedding_batch_size
+or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE")
)
similarity_search_topk = int(
-os.getenv(
-"KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE",
-config.similarity_search_topk,
-)
+config.similarity_search_topk
+or os.getenv("KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE")
)
similarity_search_score_threshold = float(
-os.getenv(
-"KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE",
-config.extract_score_threshold,
-)
+config.extract_score_threshold
+or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE")
)
-self._enable_text_search = (
-os.environ["TEXT_SEARCH_ENABLED"].lower() == "true"
-if "TEXT_SEARCH_ENABLED" in os.environ
-else config.enable_text_search
+self._enable_text_search = config.enable_text_search or (
+os.getenv("TEXT_SEARCH_ENABLED", "").lower() == "true"
)
+text2gql_model_enabled = config.text2gql_model_enabled or (
+os.getenv("TEXT2GQL_MODEL_ENABLED", "").lower() == "true"
+)
+text2gql_model_name = config.text2gql_model_name or os.getenv(
+"TEXT2GQL_MODEL_NAME"
+)
-text2gql_model_enabled = (
-os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
-if "TEXT2GQL_MODEL_ENABLED" in os.environ
-else config.text2gql_model_enabled
-)
-text2gql_model_name = os.getenv(
-"TEXT2GQL_MODEL_NAME",
-config.text2gql_model_name,
-)
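The behavioral difference between the removed and added lines shows up when both sources are set: previously a defined environment variable overrode the config field, now the config field wins and the environment only fills gaps. A small illustration with made-up values:

```python
import os

# Illustration only: compare the old and new resolution orders for one flag.
os.environ["TEXT_SEARCH_ENABLED"] = "false"
config_enable_text_search = True

# Old style: the environment variable wins whenever it is defined.
old_value = (
    os.environ["TEXT_SEARCH_ENABLED"].lower() == "true"
    if "TEXT_SEARCH_ENABLED" in os.environ
    else config_enable_text_search
)

# New style: the config field wins; the environment is only a fallback.
new_value = config_enable_text_search or (
    os.getenv("TEXT_SEARCH_ENABLED", "").lower() == "true"
)

print(old_value, new_value)  # False True
```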

View File

@@ -95,7 +95,7 @@ logger = logging.getLogger(__name__)
float,
description=_("Recall score of community search in knowledge graph"),
optional=True,
-default=0.0,
+default=0.3,
),
Parameter.build_from(
_("Enable the graph search for documents and chunks"),
@@ -171,7 +171,7 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
description="Topk of community search in knowledge graph",
)
community_score_threshold: float = Field(
-default=0.0,
+default=0.3,
description="Recall score of community search in knowledge graph",
)
triplet_graph_enabled: bool = Field(
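Both the Parameter definition and the Pydantic field raise the default community_score_threshold from 0.0 to 0.3, so low-scoring communities are no longer recalled by default. A hedged sketch of how such a recall threshold is typically applied; the ScoredCommunity type and the scores below are made up for illustration and are not DB-GPT types:

```python
from dataclasses import dataclass

# Hypothetical container; not a DB-GPT type.
@dataclass
class ScoredCommunity:
    summary: str
    score: float

candidates = [
    ScoredCommunity("community A", 0.82),
    ScoredCommunity("community B", 0.31),
    ScoredCommunity("community C", 0.12),
]

community_score_threshold = 0.3
recalled = [c for c in candidates if c.score >= community_score_threshold]
# With the old default of 0.0 every candidate was recalled; with 0.3 the
# lowest-scoring community is filtered out.
print([c.summary for c in recalled])  # ['community A', 'community B']
```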
@@ -244,56 +244,41 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
super().__init__(config)
self._config = config
-self._vector_store_type = os.getenv(
-"VECTOR_STORE_TYPE", config.vector_store_type
+self._vector_store_type = config.vector_store_type or os.getenv(
+"VECTOR_STORE_TYPE"
)
self._extract_topk = int(
os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE", config.extract_topk)
config.extract_topk or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE")
)
self._extract_score_threshold = float(
-os.getenv(
-"KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE",
-config.extract_score_threshold,
-)
+config.extract_score_threshold
+or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE")
)
self._community_topk = int(
-os.getenv(
-"KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE", config.community_topk
-)
+config.community_topk
+or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE")
)
self._community_score_threshold = float(
-os.getenv(
-"KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE",
-config.community_score_threshold,
-)
+config.community_score_threshold
+or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE")
)
-self._document_graph_enabled = (
-os.environ["DOCUMENT_GRAPH_ENABLED"].lower() == "true"
-if "DOCUMENT_GRAPH_ENABLED" in os.environ
-else config.document_graph_enabled
+self._document_graph_enabled = config.document_graph_enabled or (
+os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
)
-self._triplet_graph_enabled = (
-os.environ["TRIPLET_GRAPH_ENABLED"].lower() == "true"
-if "TRIPLET_GRAPH_ENABLED" in os.environ
-else config.triplet_graph_enabled
+self._triplet_graph_enabled = config.triplet_graph_enabled or (
+os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
)
self._triplet_extraction_batch_size = int(
-os.getenv(
-"KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE",
-config.knowledge_graph_extraction_batch_size,
-)
+config.knowledge_graph_extraction_batch_size
+or os.getenv("KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE")
)
self._triplet_embedding_batch_size = int(
-os.getenv(
-"KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE",
-config.knowledge_graph_embedding_batch_size,
-)
+config.knowledge_graph_embedding_batch_size
+or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE")
)
self._community_summary_batch_size = int(
-os.getenv(
-"COMMUNITY_SUMMARY_BATCH_SIZE",
-config.community_summary_batch_size,
-)
+config.community_summary_batch_size
+or os.getenv("COMMUNITY_SUMMARY_BATCH_SIZE")
)
def extractor_configure(name: str, cfg: VectorStoreConfig):
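Several of the resolved settings above are batch sizes (triplet extraction, embedding, community summary). A minimal batching sketch, assuming chunks are processed in fixed-size groups; the helper below is illustrative, not the project's extractor API:

```python
from typing import Iterator, List

def batched(items: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield successive fixed-size batches from a list of chunks."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

triplet_extraction_batch_size = 20  # e.g. knowledge_graph_extraction_batch_size
chunks = [f"chunk-{i}" for i in range(45)]

for batch in batched(chunks, triplet_extraction_batch_size):
    # Each batch would be handed to the triplet extractor as one request.
    print(len(batch))  # 20, 20, 5
```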

View File

@@ -149,7 +149,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
cfg.name = config.name
cfg.embedding_fn = config.embedding_fn
-graph_store_type = os.getenv("GRAPH_STORE_TYPE") or config.type
+graph_store_type = config.type or os.getenv("GRAPH_STORE_TYPE")
return GraphStoreFactory.create(graph_store_type, configure, config.dict())
def __init_graph_store_adapter(self):
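The resolved type string is handed to GraphStoreFactory.create. A registry-style sketch of what such a lookup can look like; the registry, decorator, and stub class below are hypothetical stand-ins, not the actual GraphStoreFactory implementation:

```python
from typing import Callable, Dict

# Hypothetical registry; the real GraphStoreFactory may work differently.
_REGISTRY: Dict[str, Callable[..., object]] = {}

def register(store_type: str):
    def decorator(cls):
        _REGISTRY[store_type.lower()] = cls
        return cls
    return decorator

@register("tugraph")
class TuGraphStoreStub:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

def create(store_type: str, **kwargs):
    try:
        return _REGISTRY[store_type.lower()](**kwargs)
    except KeyError:
        raise ValueError(f"Unknown graph store type: {store_type}")

store = create("TuGraph", host="127.0.0.1", port=7687)
print(type(store).__name__)  # TuGraphStoreStub
```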

View File

@@ -96,7 +96,7 @@ class VectorStoreConnector:
def __rewrite_index_store_type(self, index_store_type):
# Rewrite Knowledge Graph Type
-if self.app_config.rag.graph_community_summary_enabled:
+if self.app_config.rag.storage.graph.get("enable_summary").lower() == "true":
if index_store_type == "KnowledgeGraph":
return "CommunitySummaryKnowledgeGraph"
return index_store_type
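The new condition reads the flag from the storage.graph section as a quoted string rather than from a dedicated boolean field. A defensive sketch of the same check, assuming the section is a plain dict; the guard against a missing enable_summary key is added here for illustration and is not part of the committed code:

```python
def rewrite_index_store_type(index_store_type: str, graph_section: dict) -> str:
    # Treat a missing or non-"true" enable_summary value as disabled.
    enable_summary = (graph_section.get("enable_summary") or "").lower() == "true"
    if enable_summary and index_store_type == "KnowledgeGraph":
        return "CommunitySummaryKnowledgeGraph"
    return index_store_type

print(rewrite_index_store_type("KnowledgeGraph", {"enable_summary": "True"}))
# -> CommunitySummaryKnowledgeGraph
print(rewrite_index_store_type("KnowledgeGraph", {}))
# -> KnowledgeGraph
```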