mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 07:34:07 +00:00
refactor(GraphRAG): refine config usage and fix some bugs. (#2392)
Co-authored-by: 秉翟 <lyusonglin.lsl@antgroup.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com> Co-authored-by: yyhhyyyyyy <95077259+yyhhyyyyyy@users.noreply.github.com> Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
parent
d5a2a0bf3b
commit
3bd75d8de2
@ -35,8 +35,26 @@ host="127.0.0.1"
|
||||
port=7687
|
||||
username="admin"
|
||||
password="73@TuGraph"
|
||||
#enable_summary="True"
|
||||
#enable_similarity_search="True"
|
||||
|
||||
# enable_summary="True"
|
||||
# community_topk=20
|
||||
# community_score_threshold=0.3
|
||||
|
||||
# triplet_graph_enabled="True"
|
||||
# extract_topk=20
|
||||
|
||||
# document_graph_enabled="True"
|
||||
# knowledge_graph_chunk_search_top_size=20
|
||||
# knowledge_graph_extraction_batch_size=20
|
||||
|
||||
# enable_similarity_search="True"
|
||||
# knowledge_graph_embedding_batch_size=20
|
||||
# similarity_search_topk=5
|
||||
# extract_score_threshold=0.7
|
||||
|
||||
# enable_text_search="True"
|
||||
# text2gql_model_enabled="True"
|
||||
# text2gql_model_name="qwen2.5:latest"
|
||||
|
||||
|
||||
|
||||
|
@ -37,24 +37,18 @@ class GraphRetriever(GraphRetrieverBase):
|
||||
graph_store_adapter,
|
||||
):
|
||||
"""Initialize Graph Retriever."""
|
||||
self._triplet_graph_enabled = (
|
||||
os.environ["TRIPLET_GRAPH_ENABLED"].lower() == "true"
|
||||
if "TRIPLET_GRAPH_ENABLED" in os.environ
|
||||
else config.triplet_graph_enabled
|
||||
self._triplet_graph_enabled = config.triplet_graph_enabled or (
|
||||
os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
|
||||
)
|
||||
self._document_graph_enabled = (
|
||||
os.environ["DOCUMENT_GRAPH_ENABLED"].lower() == "true"
|
||||
if "DOCUMENT_GRAPH_ENABLED" in os.environ
|
||||
else config.document_graph_enabled
|
||||
self._document_graph_enabled = config.document_graph_enabled or (
|
||||
os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
|
||||
)
|
||||
triplet_topk = int(
|
||||
os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE", config.extract_topk)
|
||||
config.extract_topk or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE")
|
||||
)
|
||||
document_topk = int(
|
||||
os.getenv(
|
||||
"KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE",
|
||||
config.knowledge_graph_chunk_search_top_size,
|
||||
)
|
||||
config.knowledge_graph_chunk_search_top_size
|
||||
or os.getenv("KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE")
|
||||
)
|
||||
llm_client = config.llm_client
|
||||
model_name = config.model_name
|
||||
@ -62,27 +56,43 @@ class GraphRetriever(GraphRetrieverBase):
|
||||
graph_store_adapter.graph_store.enable_similarity_search
|
||||
)
|
||||
self._embedding_batch_size = int(
|
||||
os.getenv(
|
||||
"KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE",
|
||||
config.knowledge_graph_embedding_batch_size,
|
||||
)
|
||||
config.knowledge_graph_embedding_batch_size
|
||||
or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE")
|
||||
)
|
||||
similarity_search_topk = int(
|
||||
os.getenv(
|
||||
"KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE",
|
||||
config.similarity_search_topk,
|
||||
)
|
||||
config.similarity_search_topk
|
||||
or os.getenv("KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE")
|
||||
)
|
||||
similarity_search_score_threshold = float(
|
||||
os.getenv(
|
||||
"KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE",
|
||||
config.extract_score_threshold,
|
||||
)
|
||||
config.extract_score_threshold
|
||||
or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE")
|
||||
)
|
||||
self._enable_text_search = (
|
||||
os.environ["TEXT_SEARCH_ENABLED"].lower() == "true"
|
||||
if "TEXT_SEARCH_ENABLED" in os.environ
|
||||
else config.enable_text_search
|
||||
self._enable_text_search = config.enable_text_search or (
|
||||
os.getenv("TEXT_SEARCH_ENABLED", "").lower() == "true"
|
||||
)
|
||||
text2gql_model_enabled = config.text2gql_model_enabled or (
|
||||
os.getenv("TEXT2GQL_MODEL_ENABLED", "").lower() == "true"
|
||||
)
|
||||
text2gql_model_name = config.text2gql_model_name or os.getenv(
|
||||
"TEXT2GQL_MODEL_NAME"
|
||||
)
|
||||
text2gql_model_enabled = (
|
||||
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
|
||||
if "TEXT2GQL_MODEL_ENABLED" in os.environ
|
||||
else config.text2gql_model_enabled
|
||||
)
|
||||
text2gql_model_name = os.getenv(
|
||||
"TEXT2GQL_MODEL_NAME",
|
||||
config.text2gql_model_name,
|
||||
)
|
||||
text2gql_model_enabled = (
|
||||
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
|
||||
if "TEXT2GQL_MODEL_ENABLED" in os.environ
|
||||
else config.text2gql_model_enabled
|
||||
)
|
||||
text2gql_model_name = os.getenv(
|
||||
"TEXT2GQL_MODEL_NAME",
|
||||
config.text2gql_model_name,
|
||||
)
|
||||
text2gql_model_enabled = (
|
||||
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
|
||||
|
@ -95,7 +95,7 @@ logger = logging.getLogger(__name__)
|
||||
float,
|
||||
description=_("Recall score of community search in knowledge graph"),
|
||||
optional=True,
|
||||
default=0.0,
|
||||
default=0.3,
|
||||
),
|
||||
Parameter.build_from(
|
||||
_("Enable the graph search for documents and chunks"),
|
||||
@ -171,7 +171,7 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
|
||||
description="Topk of community search in knowledge graph",
|
||||
)
|
||||
community_score_threshold: float = Field(
|
||||
default=0.0,
|
||||
default=0.3,
|
||||
description="Recall score of community search in knowledge graph",
|
||||
)
|
||||
triplet_graph_enabled: bool = Field(
|
||||
@ -244,56 +244,41 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
super().__init__(config)
|
||||
self._config = config
|
||||
|
||||
self._vector_store_type = os.getenv(
|
||||
"VECTOR_STORE_TYPE", config.vector_store_type
|
||||
self._vector_store_type = config.vector_store_type or os.getenv(
|
||||
"VECTOR_STORE_TYPE"
|
||||
)
|
||||
self._extract_topk = int(
|
||||
os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE", config.extract_topk)
|
||||
config.extract_topk or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE")
|
||||
)
|
||||
self._extract_score_threshold = float(
|
||||
os.getenv(
|
||||
"KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE",
|
||||
config.extract_score_threshold,
|
||||
)
|
||||
config.extract_score_threshold
|
||||
or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE")
|
||||
)
|
||||
self._community_topk = int(
|
||||
os.getenv(
|
||||
"KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE", config.community_topk
|
||||
)
|
||||
config.community_topk
|
||||
or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE")
|
||||
)
|
||||
self._community_score_threshold = float(
|
||||
os.getenv(
|
||||
"KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE",
|
||||
config.community_score_threshold,
|
||||
)
|
||||
config.community_score_threshold
|
||||
or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE")
|
||||
)
|
||||
self._document_graph_enabled = (
|
||||
os.environ["DOCUMENT_GRAPH_ENABLED"].lower() == "true"
|
||||
if "DOCUMENT_GRAPH_ENABLED" in os.environ
|
||||
else config.document_graph_enabled
|
||||
self._document_graph_enabled = config.document_graph_enabled or (
|
||||
os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
|
||||
)
|
||||
self._triplet_graph_enabled = (
|
||||
os.environ["TRIPLET_GRAPH_ENABLED"].lower() == "true"
|
||||
if "TRIPLET_GRAPH_ENABLED" in os.environ
|
||||
else config.triplet_graph_enabled
|
||||
self._triplet_graph_enabled = config.triplet_graph_enabled or (
|
||||
os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
|
||||
)
|
||||
self._triplet_extraction_batch_size = int(
|
||||
os.getenv(
|
||||
"KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE",
|
||||
config.knowledge_graph_extraction_batch_size,
|
||||
)
|
||||
config.knowledge_graph_extraction_batch_size
|
||||
or os.getenv("KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE")
|
||||
)
|
||||
self._triplet_embedding_batch_size = int(
|
||||
os.getenv(
|
||||
"KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE",
|
||||
config.knowledge_graph_embedding_batch_size,
|
||||
)
|
||||
config.knowledge_graph_embedding_batch_size
|
||||
or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE")
|
||||
)
|
||||
self._community_summary_batch_size = int(
|
||||
os.getenv(
|
||||
"COMMUNITY_SUMMARY_BATCH_SIZE",
|
||||
config.community_summary_batch_size,
|
||||
)
|
||||
config.community_summary_batch_size
|
||||
or os.getenv("COMMUNITY_SUMMARY_BATCH_SIZE")
|
||||
)
|
||||
|
||||
def extractor_configure(name: str, cfg: VectorStoreConfig):
|
||||
|
@ -149,7 +149,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
cfg.name = config.name
|
||||
cfg.embedding_fn = config.embedding_fn
|
||||
|
||||
graph_store_type = os.getenv("GRAPH_STORE_TYPE") or config.type
|
||||
graph_store_type = config.type or os.getenv("GRAPH_STORE_TYPE")
|
||||
return GraphStoreFactory.create(graph_store_type, configure, config.dict())
|
||||
|
||||
def __init_graph_store_adapter(self):
|
||||
|
@ -96,7 +96,7 @@ class VectorStoreConnector:
|
||||
|
||||
def __rewrite_index_store_type(self, index_store_type):
|
||||
# Rewrite Knowledge Graph Type
|
||||
if self.app_config.rag.graph_community_summary_enabled:
|
||||
if self.app_config.rag.storage.graph.get("enable_summary").lower() == "true":
|
||||
if index_store_type == "KnowledgeGraph":
|
||||
return "CommunitySummaryKnowledgeGraph"
|
||||
return index_store_type
|
||||
|
Loading…
Reference in New Issue
Block a user