feat: Enhance triplet extraction in the knowledge graph with a configurable batch size (#2091)

Author: Appointat
Date: 2024-11-05 14:01:18 +08:00
Committed by: GitHub
Parent: b4ce217ded
Commit: 25d47ce343
10 changed files with 360 additions and 242 deletions


@@ -167,6 +167,7 @@ TRIPLET_GRAPH_ENABLED=True # enable the graph search for triplets
DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks
KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the top size of knowledge graph search for chunks
KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE=20 # the batch size of triplet extraction from the text

### Chroma vector db config
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data


@@ -1,4 +1,5 @@
"""Transformer base class.""" """Transformer base class."""
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional from typing import List, Optional
@@ -37,6 +38,15 @@ class ExtractorBase(TransformerBase, ABC):
    async def extract(self, text: str, limit: Optional[int] = None) -> List:
        """Extract results from text."""
    @abstractmethod
    async def batch_extract(
        self,
        texts: List[str],
        batch_size: int = 1,
        limit: Optional[int] = None,
    ) -> List:
        """Batch extract results from texts."""
class TranslatorBase(TransformerBase, ABC):
    """Translator base class."""


@@ -1,8 +1,9 @@
"""GraphExtractor class.""" """GraphExtractor class."""
import asyncio
import logging import logging
import re import re
from typing import List, Optional from typing import Dict, List, Optional
from dbgpt.core import Chunk, LLMClient from dbgpt.core import Chunk, LLMClient
from dbgpt.rag.transformer.llm_extractor import LLMExtractor from dbgpt.rag.transformer.llm_extractor import LLMExtractor
@@ -23,35 +24,96 @@ class GraphExtractor(LLMExtractor):
        self._chunk_history = chunk_history
        config = self._chunk_history.get_config()
        self._vector_space = config.name
        self._max_chunks_once_load = config.max_chunks_once_load
        self._max_threads = config.max_threads
        self._topk = config.topk
        self._score_threshold = config.score_threshold
    async def aload_chunk_context(self, texts: List[str]) -> Dict[str, str]:
        """Load chunk context."""
        text_context_map: Dict[str, str] = {}

        for text in texts:
            # Load similar chunks
            chunks = await self._chunk_history.asimilar_search_with_scores(
                text, self._topk, self._score_threshold
            )
            history = [
                f"Section {i + 1}:\n{chunk.content}" for i, chunk in enumerate(chunks)
            ]

            # Save chunk to history
            await self._chunk_history.aload_document_with_limit(
                [Chunk(content=text, metadata={"relevant_cnt": len(history)})],
                self._max_chunks_once_load,
                self._max_threads,
            )

            # Save chunk context to map
            context = "\n".join(history) if history else ""
            text_context_map[text] = context
        return text_context_map

    async def extract(self, text: str, limit: Optional[int] = None) -> List:
        """Extract graphs from text.

        Suggestion: to extract triplets in batches, call `batch_extract`.
        """
        # Load similar chunks
        text_context_map = await self.aload_chunk_context([text])
        context = text_context_map[text]

        # Extract with chunk history
        return await super()._extract(text, context, limit)
    async def batch_extract(
        self,
        texts: List[str],
        batch_size: int = 1,
        limit: Optional[int] = None,
    ) -> List[List[Graph]]:
        """Extract graphs from chunks in batches.

        Returns list of graphs in same order as input texts (text <-> graphs).
        """
        if batch_size < 1:
            raise ValueError("batch_size >= 1")

        # 1. Load chunk context
        text_context_map = await self.aload_chunk_context(texts)

        # Pre-allocate results list to maintain order
        graphs_list: List[List[Graph]] = [None] * len(texts)
        total_batches = (len(texts) + batch_size - 1) // batch_size

        for batch_idx in range(total_batches):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, len(texts))
            batch_texts = texts[start_idx:end_idx]

            # 2. Create tasks with their original indices
            extraction_tasks = [
                (
                    idx,
                    self._extract(text, text_context_map[text], limit),
                )
                for idx, text in enumerate(batch_texts, start=start_idx)
            ]

            # 3. Process extraction in parallel while keeping track of indices
            batch_results = await asyncio.gather(
                *(task for _, task in extraction_tasks)
            )

            # 4. Place results in the correct positions
            for (idx, _), graphs in zip(extraction_tasks, batch_results):
                graphs_list[idx] = graphs

        assert all(x is not None for x in graphs_list), "All positions should be filled"
        return graphs_list
    def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph]:
        graph = MemoryGraph()
        edge_count = 0
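To make the batching arithmetic above concrete (illustrative numbers only): with 10 texts and batch_size=4, total_batches = (10 + 4 - 1) // 4 = 3, so asyncio.gather runs over the slices [0:4], [4:8], and [8:10], and each result is written back to graphs_list at the original index of its text.

    # Sketch of the slice boundaries computed by batch_extract (hypothetical sizes).
    texts, batch_size = list(range(10)), 4
    total_batches = (len(texts) + batch_size - 1) // batch_size  # ceil(10 / 4) == 3
    slices = [
        (b * batch_size, min((b + 1) * batch_size, len(texts)))
        for b in range(total_batches)
    ]
    print(slices)  # [(0, 4), (4, 8), (8, 10)]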


@@ -1,4 +1,6 @@
"""TripletExtractor class.""" """TripletExtractor class."""
import asyncio
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional from typing import List, Optional
@@ -22,6 +24,32 @@ class LLMExtractor(ExtractorBase, ABC):
"""Extract by LLM.""" """Extract by LLM."""
return await self._extract(text, None, limit) return await self._extract(text, None, limit)
    async def batch_extract(
        self,
        texts: List[str],
        batch_size: int = 1,
        limit: Optional[int] = None,
    ) -> List:
        """Batch extract by LLM."""
        if batch_size < 1:
            raise ValueError("batch_size >= 1")

        results = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i : i + batch_size]

            # Create tasks for current batch
            extraction_tasks = [
                self._extract(text, None, limit) for text in batch_texts
            ]

            # Execute batch concurrently and wait for all to complete
            batch_results = await asyncio.gather(*extraction_tasks)
            results.extend(batch_results)

        return results
    async def _extract(
        self, text: str, history: str = None, limit: Optional[int] = None
    ) -> List:
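The base-class version above relies on the simpler texts[i : i + batch_size] slicing loop: each round issues at most batch_size concurrent _extract calls and waits for them before starting the next round, which bounds the parallelism seen by the LLM endpoint. A small, hypothetical illustration of that round structure:

    # With 5 texts and batch_size=2, this mirrors batch_extract's rounds of gather() calls.
    texts, batch_size = ["t1", "t2", "t3", "t4", "t5"], 2
    rounds = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    print(rounds)  # [['t1', 't2'], ['t3', 't4'], ['t5']] -> three rounds, at most 2 concurrent calls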


@@ -1,4 +1,5 @@
"""TripletExtractor class.""" """TripletExtractor class."""
import logging import logging
import re import re
from typing import Any, List, Optional, Tuple from typing import Any, List, Optional, Tuple
@@ -12,7 +13,7 @@ TRIPLET_EXTRACT_PT = (
"Some text is provided below. Given the text, " "Some text is provided below. Given the text, "
"extract up to knowledge triplets as more as possible " "extract up to knowledge triplets as more as possible "
"in the form of (subject, predicate, object).\n" "in the form of (subject, predicate, object).\n"
"Avoid stopwords.\n" "Avoid stopwords. The subject, predicate, object can not be none.\n"
"---------------------\n" "---------------------\n"
"Example:\n" "Example:\n"
"Text: Alice is Bob's mother.\n" "Text: Alice is Bob's mother.\n"


@@ -27,14 +27,6 @@ class GraphStoreConfig(BaseModel):
        default=False,
        description="Enable graph community summary or not.",
    )
    document_graph_enabled: bool = Field(
        default=True,
        description="Enable document graph search or not.",
    )
    triplet_graph_enabled: bool = Field(
        default=True,
        description="Enable knowledge graph search or not.",
    )
class GraphStoreBase(ABC):


@@ -83,14 +83,6 @@ class TuGraphStore(GraphStoreBase):
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true" os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
or config.enable_summary or config.enable_summary
) )
        self._enable_document_graph = (
            os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
            or config.document_graph_enabled
        )
        self._enable_triplet_graph = (
            os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
            or config.triplet_graph_enabled
        )
        self._plugin_names = (
            os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
            or config.plugin_names


@@ -544,7 +544,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
        if not subs:
            return MemoryGraph()

        if depth <= 0:
            depth = 3
        depth_string = f"1..{depth}"
@@ -566,23 +566,95 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} " f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
f"RETURN p {limit_string}" f"RETURN p {limit_string}"
) )
return self.query(query) return self.query(query=query, white_list=["description"])
else: else:
            # If there exists the entities in the graph, return the graph that
            # includes the leaf chunks that connect to the entities, the chains from
            # documents to the leaf chunks, and the chain from documents to chunks;
            # document -> chunk -> chunk -> ... -> leaf chunk -> (entity)
            #
            # If not, return the graph that includes the chains from documents to chunks
            # that contain the subs (keywords).
            # document -> chunk -> chunk -> ... -> leaf chunk (that contains the subs)
            #
            # And only the leaf chunks contain the content, and the other chunks do not
            # contain any properties except the id, name.
            graph = MemoryGraph()
            # Check if the entities exist in the graph
            check_entity_query = (
                f"MATCH (n:{GraphElemType.ENTITY.value}) "
                f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
                "RETURN n"
            )

            if self.query(check_entity_query):
                # Query the leaf chunks in the chain from documents to chunks
                leaf_chunk_query = (
                    f"MATCH p=(n:{GraphElemType.CHUNK.value})-"
                    f"[r:{GraphElemType.INCLUDE.value}]->"
                    f"(m:{GraphElemType.ENTITY.value})"
                    f"WHERE m.name IN {[self._escape_quotes(sub) for sub in subs]} "
                    f"RETURN n"
                )
                graph_of_leaf_chunks = self.query(
                    query=leaf_chunk_query, white_list=["content"]
                )

                # Query the chain from documents to chunks,
                # document -> chunk -> ... -> leaf_chunks
                chunk_names = [
                    self._escape_quotes(vertex.name)
                    for vertex in graph_of_leaf_chunks.vertices()
                ]
                chain_query = (
                    f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
                    f"[:{GraphElemType.INCLUDE.value}*{depth_string}]->"
                    f"(m:{GraphElemType.CHUNK.value})"
                    f"WHERE m.name IN {chunk_names} "
                    "RETURN p"
                )
                # Filter all the properties by white_list
                graph.upsert_graph(self.query(query=chain_query, white_list=[""]))

                # The number of leaf chunks compared to the `limit`
                if not limit or len(chunk_names) <= limit:
                    graph.upsert_graph(graph_of_leaf_chunks)
                else:
                    limited_leaf_chunk_query = leaf_chunk_query + f" {limit_string}"
                    graph.upsert_graph(
                        self.query(
                            query=limited_leaf_chunk_query, white_list=["content"]
                        )
                    )
            else:
                _subs_condition = " OR ".join(
                    [f"m.content CONTAINS '{self._escape_quotes(sub)}'" for sub in subs]
                )

                # Query the chain from documents to chunks,
                # document -> chunk -> chunk -> chunk -> ... -> chunk
                chain_query = (
                    f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
                    f"[r:{GraphElemType.INCLUDE.value}*{depth_string}]->"
                    f"(m:{GraphElemType.CHUNK.value})"
                    f"WHERE {_subs_condition}"
                    "RETURN p"
                )
                # Filter all the properties by white_list
                graph.upsert_graph(self.query(query=chain_query, white_list=[""]))

                # Query the leaf chunks in the chain from documents to chunks
                leaf_chunk_query = (
                    f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
                    f"[r:{GraphElemType.INCLUDE.value}*{depth_string}]->"
                    f"(m:{GraphElemType.CHUNK.value})"
                    f"WHERE {_subs_condition}"
                    f"RETURN m {limit_string}"
                )
                graph.upsert_graph(
                    self.query(query=leaf_chunk_query, white_list=["content"])
                )
        return graph
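A hedged sketch of how the rewritten branch is reached from the retriever side (the adapter variable and keyword list are illustrative): with search_scope="document_graph" the call now returns the document-to-chunk chains plus the leaf chunks, and only the leaf chunks carry the "content" property.

    # Hypothetical call; adapter is a TuGraphStoreAdapter instance.
    doc_graph = adapter.explore(
        subs=["philz coffee"],          # keywords extracted from the user question
        limit=5,                        # caps the number of leaf chunks returned
        search_scope="document_graph",
    )
    for v in doc_graph.vertices():
        print(v.name)  # leaf chunk vertices also carry their "content" property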
@@ -607,6 +679,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
        vertices, edges = self._get_nodes_edges_from_queried_data(
            query_result, white_list
        )
        mg = MemoryGraph()
        for vertex in vertices:
            mg.upsert_vertex(vertex)
@@ -714,7 +787,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
        from neo4j import graph

        def filter_properties(
            properties: dict[str, Any], white_list: Optional[List[str]] = None
        ) -> Dict[str, Any]:
            """Filter the properties.
@@ -723,13 +796,26 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
entity_properties = ["id", "name", "description", "_document_id", entity_properties = ["id", "name", "description", "_document_id",
"_chunk_id", "_community_id"] "_chunk_id", "_community_id"]
edge_properties = ["id", "name", "description", "_chunk_id"] edge_properties = ["id", "name", "description", "_chunk_id"]
            Args:
                properties: Dictionary of properties to filter
                white_list: List of properties to keep
                    - If None: Keep default properties (those not starting with '_'
                        and not in ['id', 'name'])
                    - If [""]: Remove all properties (return empty dict)
                    - If list of strings: Keep only properties in white_list
            """
            return (
                {}
                if white_list == [""]
                else {
                    key: value
                    for key, value in properties.items()
                    if (
                        (not key.startswith("_") and key not in ["id", "name"])
                        or (white_list is not None and key in white_list)
                    )
                }
            )
        # Parse the data to nodes and relationships
        for record in data:
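A small, hypothetical illustration of the three white_list modes documented above (the property dict is made up):

    props = {"id": "n1", "name": "Alice", "description": "a person", "_chunk_id": "c1"}

    filter_properties(props)                             # -> {'description': 'a person'}
    filter_properties(props, white_list=[""])            # -> {}  (strip every property)
    filter_properties(props, white_list=["_chunk_id"])   # -> {'description': 'a person', '_chunk_id': 'c1'}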


@@ -9,7 +9,6 @@ from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import Chunk
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
from dbgpt.rag.transformer.graph_extractor import GraphExtractor
from dbgpt.storage.graph_store.graph import MemoryGraph
from dbgpt.storage.knowledge_graph.base import ParagraphChunk
from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore
from dbgpt.storage.knowledge_graph.knowledge_graph import (
@@ -59,10 +58,23 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
        default=0.0,
        description="Recall score of community search in knowledge graph",
    )
    triplet_graph_enabled: bool = Field(
        default=True,
        description="Enable the graph search for triplets",
    )
    document_graph_enabled: bool = Field(
        default=True,
        description="Enable the graph search for documents and chunks",
    )
    knowledge_graph_chunk_search_top_size: int = Field(
        default=5,
        description="Top size of knowledge graph chunk search",
    )
    knowledge_graph_extraction_batch_size: int = Field(
        default=20,
        description="Batch size of triplets extraction from the text",
    )

class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
@@ -96,6 +108,28 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
                config.community_score_threshold,
            )
        )
        self._document_graph_enabled = (
            os.environ["DOCUMENT_GRAPH_ENABLED"].lower() == "true"
            if "DOCUMENT_GRAPH_ENABLED" in os.environ
            else config.document_graph_enabled
        )
        self._triplet_graph_enabled = (
            os.environ["TRIPLET_GRAPH_ENABLED"].lower() == "true"
            if "TRIPLET_GRAPH_ENABLED" in os.environ
            else config.triplet_graph_enabled
        )
        self._knowledge_graph_chunk_search_top_size = int(
            os.getenv(
                "KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE",
                config.knowledge_graph_chunk_search_top_size,
            )
        )
        self._triplet_extraction_batch_size = int(
            os.getenv(
                "KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE",
                config.knowledge_graph_extraction_batch_size,
            )
        )
        def extractor_configure(name: str, cfg: VectorStoreConfig):
            cfg.name = name
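A minimal sketch of the precedence these lines implement (values are illustrative): an exported environment variable overrides the config field, and the config default applies when the variable is unset.

    import os

    os.environ["KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE"] = "8"  # hypothetical override
    batch_size = int(os.getenv("KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE", 20))
    assert batch_size == 8  # without the override it falls back to the default of 20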
@@ -154,7 +188,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
        The chunks include the doc structure.
        """
        if not self._document_graph_enabled:
            return

        _chunks: List[ParagraphChunk] = [
@@ -185,33 +219,35 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
        The chunks include the doc structure.
        """
        if not self._triplet_graph_enabled:
            return

        document_graph_enabled = self._document_graph_enabled

        # Extract the triplets from the chunks, and return the list of graphs
        # in the same order as the input texts
        graphs_list = await self._graph_extractor.batch_extract(
            [chunk.content for chunk in chunks],
            batch_size=self._triplet_extraction_batch_size,
        )

        # Upsert the graphs into the graph store
        for idx, graphs in enumerate(graphs_list):
            for graph in graphs:
                if document_graph_enabled:
                    # Append the chunk id to the edge
                    for edge in graph.edges():
                        edge.set_prop("_chunk_id", chunks[idx].chunk_id)
                        graph.append_edge(edge=edge)

                # Upsert the graph
                self._graph_store_apdater.upsert_graph(graph)

                # chunk -> include -> entity
                if document_graph_enabled:
                    for vertex in graph.vertices():
                        self._graph_store_apdater.upsert_chunk_include_entity(
                            chunk=chunks[idx], entity=vertex
                        )

    def _load_chunks(
@@ -285,13 +321,15 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
context = "\n".join(summaries) if summaries else "" context = "\n".join(summaries) if summaries else ""
keywords: List[str] = await self._keyword_extractor.extract(text) keywords: List[str] = await self._keyword_extractor.extract(text)
subgraph = None
subgraph_for_doc = None
# Local search: extract keywords and explore subgraph # Local search: extract keywords and explore subgraph
triplet_graph_enabled = self._graph_store.get_config().triplet_graph_enabled triplet_graph_enabled = self._triplet_graph_enabled
document_graph_enabled = self._graph_store.get_config().document_graph_enabled document_graph_enabled = self._document_graph_enabled
if triplet_graph_enabled: if triplet_graph_enabled:
subgraph: MemoryGraph = self._graph_store_apdater.explore( subgraph = self._graph_store_apdater.explore(
subs=keywords, limit=topk, search_scope="knowledge_graph" subs=keywords, limit=topk, search_scope="knowledge_graph"
) )
@@ -302,14 +340,14 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
                subgraph_for_doc = self._graph_store_apdater.explore(
                    subs=keywords_for_document_graph,
                    limit=self._knowledge_graph_chunk_search_top_size,
                    search_scope="document_graph",
                )
        else:
            if document_graph_enabled:
                subgraph_for_doc = self._graph_store_apdater.explore(
                    subs=keywords,
                    limit=self._knowledge_graph_chunk_search_top_size,
                    search_scope="document_graph",
                )

        knowledge_graph_str = subgraph.format() if subgraph else ""
@@ -323,7 +361,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
            return []

        # merge search results into context
        content = HYBRID_SEARCH_PT.format(
            context=context,
            knowledge_graph=knowledge_graph_str,
            knowledge_graph_for_doc=knowledge_graph_for_doc_str,
        )
@@ -353,179 +391,86 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
        self._graph_extractor.drop()
HYBRID_SEARCH_PT_CN = """## 角色 HYBRID_SEARCH_PT = """
你非常擅长结合提示词模板提供的[上下文]信息与[知识图谱]信息, =====
准确恰当地回答用户的问题,并保证不会输出与上下文和知识图谱无关的信息。 The following information from [Context], [Knowledge Graph], and [Original Text From RAG] can help you answer user questions better.
## 技能
### 技能 1: 上下文理解
- 准确地理解[上下文]提供的信息,上下文信息可能被拆分为多个章节。
- 上下文的每个章节内容都会以[Section]开始,并按需进行了编号。
- 上下文信息提供了与用户问题相关度最高的总结性描述,请合理使用它们。
### 技能 2: 知识图谱理解
- 准确地识别[知识图谱]中提供的[Entities:]章节中的实体信息和[Relationships:]章节中的关系信息,实体和关系信息的一般格式为:
```
* 实体信息格式:
- (实体名)
- (实体名:实体描述)
- (实体名:实体属性表)
- (文本块ID:文档块内容)
- (目录ID:目录名)
- (文档ID:文档名称)
* 关系信息的格式:
- (来源实体名)-[关系名]->(目标实体名)
- (来源实体名)-[关系名:关系描述]->(目标实体名)
- (来源实体名)-[关系名:关系属性表]->(目标实体名)
- (文本块实体)-[包含]->(实体名)
- (目录ID)-[包含]->(文本块实体)
- (目录ID)-[包含]->(子目录ID)
- (文档ID)-[包含]->(文本块实体)
- (文档ID)-[包含]->(目录ID)
```
- 正确地将关系信息中的实体名/ID与实体信息关联还原出图结构。
- 将图结构所表达的信息作为用户提问的明细上下文,辅助生成更好的答案。
## 约束条件
- 不要在答案中描述你的思考过程,直接给出用户问题的答案,不要生成无关信息。
- 若[知识图谱]或者[知识库原文]没有提供信息,此时应根据[上下文]提供的信息回答问题。
- 确保以第三人称书写,从客观角度结合[上下文]、[知识图谱]和[知识库原文]表达的信息回答问题。
- 若提供的信息相互矛盾,请解决矛盾并提供一个单一、连贯的描述。
- 避免使用停用词和过于常见的词汇。
## 参考案例
```
[上下文]:
Section 1:
菲尔・贾伯的大儿子叫雅各布・贾伯。
Section 2:
菲尔・贾伯的小儿子叫比尔・贾伯。
[知识图谱]:
Entities:
(菲尔・贾伯#菲尔兹咖啡创始人)
(菲尔兹咖啡#加利福尼亚州伯克利创立的咖啡品牌)
(雅各布・贾伯#菲尔・贾伯的儿子)
(美国多地#菲尔兹咖啡的扩展地区)
Relationships:
(菲尔・贾伯#创建#菲尔兹咖啡#1978年在加利福尼亚州伯克利创立)
(菲尔兹咖啡#位于#加利福尼亚州伯克利#菲尔兹咖啡的创立地点)
(菲尔・贾伯#拥有#雅各布・贾伯#菲尔・贾伯的儿子)
(雅各布・贾伯#担任#首席执行官#在2005年成为菲尔兹咖啡的首席执行官)
(菲尔兹咖啡#扩展至#美国多地#菲尔兹咖啡的扩展范围)
[知识库原文]:
...
```
----
接下来的[上下文]、[知识图谱]和[知识库原文]的信息,可以帮助你回答更好地用户的问题。
[上下文]:
{context}
[知识图谱]:
{knowledge_graph}
[知识库原文]
{knowledge_graph_for_doc}
""" # noqa: E501
HYBRID_SEARCH_PT_EN = """## Role
You excel at combining the information provided in the [Context] with
information from the [KnowledgeGraph] to accurately and appropriately
answer user questions, ensuring that you do not output information
unrelated to the context and knowledge graph.
## Skills
### Skill 1: Context Understanding
- Accurately understand the information provided in the [Context],
which may be divided into several sections.
- Each section in the context will start with [Section]
and may be numbered as needed.
- The context provides a summary description most relevant to the user's
question, and it should be used wisely.
### Skill 2: Knowledge Graph Understanding
- Accurately identify entity information in the [Entities:] section and
relationship information in the [Relationships:] section
of the [KnowledgeGraph]. The general format for entity
and relationship information is:
```
* Entity Information Format:
- (entity_name)
- (entity_name: entity_description)
- (entity_name: entity_property_map)
- (chunk_id: chunk_content)
- (catalog_id: catalog_name)
- (document_id: document_name)
* Relationship Information Format:
- (source_entity_name)-[relationship_name]->(target_entity_name)
- (source_entity_name)-[relationship_name: relationship_description]->(target_entity_name)
- (source_entity_name)-[relationship_name: relationship_property_map]->(target_entity_name)
- (chunk_id)-[Contains]->(entity_name)
- (catalog_id)-[Contains]->(chunk_id)
- (catalog_id)-[Contains]->(sub_catalog_id)
- (document_id)-[Contains]->(chunk_id)
- (document_id)-[Contains]->(catalog_id)
```
- Correctly associate entity names/IDs in the relationship information
with entity information to restore the graph structure.
- Use the information expressed by the graph structure as detailed
context for the user's query to assist in generating better answers.
## Constraints
- Don't describe your thought process in the answer, provide the answer
to the user's question directly without generating irrelevant information.
- If the [KnowledgeGraph] or [Knowledge base original text] does not provide information, you should answer
the question based on the information provided in the [Context].
- Ensure to write in the third person, responding to questions from
an objective perspective based on the information combined from the
[Context], the [KnowledgeGraph] and the [Knowledge base original text].
- If the provided information is contradictory, resolve the
contradictions and provide a single, coherent description.
- Avoid using stop words and overly common vocabulary.
## Reference Example
```
[Context]:
Section 1:
Phil Schiller's eldest son is Jacob Schiller.
Section 2:
Phil Schiller's youngest son is Bill Schiller.
[KnowledgeGraph]:
Entities:
(Phil Jaber#Founder of Philz Coffee)
(Philz Coffee#Coffee brand founded in Berkeley, California)
(Jacob Jaber#Son of Phil Jaber)
(Multiple locations in the USA#Expansion regions of Philz Coffee)
Relationships:
(Phil Jaber#Created#Philz Coffee#Founded in Berkeley, California in 1978)
(Philz Coffee#Located in#Berkeley, California#Founding location of Philz Coffee)
(Phil Jaber#Has#Jacob Jaber#Son of Phil Jaber)
(Jacob Jaber#Serves as#CEO#Became CEO of Philz Coffee in 2005)
(Philz Coffee#Expanded to#Multiple locations in the USA#Expansion regions of Philz Coffee)
[Knowledge base original text]
...
```
----
The following information from the [Context], [KnowledgeGraph] and [Knowledge base original text]
can help you better answer user questions.
[Context]:
{context}
[KnowledgeGraph]:
{knowledge_graph}
[Knowledge base original text]
{knowledge_graph_for_doc}
"""  # noqa: E501


HYBRID_SEARCH_PT = """
=====
The following information from [Context], [Knowledge Graph], and [Original Text From RAG] can help you answer user questions better.

[Context]:
{context}

[Knowledge Graph]:
{knowledge_graph}

[Original Text From RAG]
{knowledge_graph_for_doc}
=====
You are very good at combining the [Context] information provided by the prompt word template with the [Knowledge Graph] information,
answering the user's questions accurately and appropriately, and ensuring that no information irrelevant to the context and knowledge graph is output.
## Role: GraphRAG Assistant
### Core Capabilities
0. Make sure DO NOT answer irrelevant questions from the user.
1. Information Processing
- Process contextual information across multiple sections ([Section] markers)
- Interpret knowledge graph relationships ((entity)-[relationship]->(entity))
- Synthesize information from both structured and unstructured sources
2. Response Generation
- Provide nuanced, multi-perspective answers
- Balance technical accuracy with conversational engagement
- Connect related concepts across different information sources
- Highlight uncertainties and limitations when appropriate
3. Interaction Style
- Maintain a natural, engaging conversation flow
- Ask clarifying questions when needed
- Provide examples and analogies to illustrate complex points
- Adapt explanation depth based on user's apparent expertise
4. Knowledge Integration
- Seamlessly blend information from:
* Context sections
* Knowledge graph relationships
* Background knowledge (when appropriate)
- Prioritize relevance over comprehensiveness
- Acknowledge information gaps explicitly
5. Quality Assurance
- Verify logical consistency across sources
- Cross-reference relationships for validation
- Flag potential contradictions or ambiguities
- Provide confidence levels when appropriate
### Information Sources Handling
1. Context Processing [Context]
- Parse information from numbered sections systematically
- Identify key concepts and relationships within each section
- Track section dependencies and cross-references
- Prioritize recent/relevant sections for the query
2. Knowledge Graph Integration [Knowledge Graph]
- Parse Entities and Relationships sections separately
- Map entity-relationship-entity triples accurately
- Understand relationship directionality
- Use graph structure to find connected information
3. Original Text Reference [Original Text From RAG]
- The GraphRAG document directory is stored as an edge in relationships to show the hierarchy of the current source text in the entire document.
- Use as authoritative source for detailed information
- Cross-reference with Context and Knowledge Graph
- Extract supporting evidence and examples
- Resolve conflicts between sources using this as primary reference
### Output Format
1. Answer Structure
- Lead with synthesized core information
- Support with specific references to sources
- Include relevant entity-relationship pairs
- Conclude with confidence assessment
- Use the markdown format of the "quote" to highlight the original text (in details) from "GraphRAG"
=====
""" # noqa: E501 """ # noqa: E501


@@ -116,6 +116,7 @@ GRAPH_COMMUNITY_SUMMARY_ENABLED=True # enable the graph community summary
TRIPLET_GRAPH_ENABLED=True # enable the graph search for the triplets
DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks
KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the number of the searched triplets in a retrieval
KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE=20 # the batch size of triplet extraction from the text
```