Support embedding similarity search for GraphRAG (#2200)
Co-authored-by: Appointat <kuda.czk@antgroup.com>
Co-authored-by: Lyu Songlin <lyusonglin990625@gmail.com>
Co-authored-by: SonglinLyu <111941624+SonglinLyu@users.noreply.github.com>
@@ -162,17 +162,20 @@ VECTOR_STORE_TYPE=Chroma
GRAPH_STORE_TYPE=TuGraph
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE=0.3
KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_RECALL_SCORE=0.7
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE=20
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE=0.0

GRAPH_COMMUNITY_SUMMARY_ENABLED=True # enable the graph community summary
TRIPLET_GRAPH_ENABLED=True # enable the graph search for triplets
DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks
SIMILARITY_SEARCH_ENABLED=True # enable the similarity search for entities and chunks

KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the top size of knowledge graph search for chunks
KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE=20 # the batch size of triplet extraction from the text
COMMUNITY_SUMMARY_BATCH_SIZE=20 # the batch size of parallel community summary process

KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE=20 # the batch size of embedding from the text
### Chroma vector db config
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data
@@ -2,7 +2,7 @@

import json
import uuid
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from dbgpt._private.pydantic import BaseModel, Field, model_to_dict

@@ -63,6 +63,9 @@ class Chunk(Document):
        description="Separator between metadata fields when converting to string.",
    )
    retriever: Optional[str] = Field(default=None, description="retriever name")
    embedding: Optional[List[float]] = Field(
        default=None, description="chunk embedding"
    )

    def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
        """Convert Chunk to dict."""
@@ -75,6 +75,15 @@ class TuGraphConnector(BaseConnector):
                "`pip install neo4j`"
            ) from err

    def get_system_info(self) -> Dict:
        """Get system info from the TuGraph."""
        with self._driver.session(database="default") as session:
            system_info_list = session.run("CALL dbms.system.info()")
            system_info = {}
            for info in system_info_list:
                system_info[info["name"]] = info["value"]
            return system_info

    def get_table_names(self) -> Iterator[str]:
        """Get all table names from the TuGraph by Neo4j driver."""
        with self._driver.session(database=self._graph) as session:
@@ -8,10 +8,7 @@ from typing import Any, Dict, List, Optional

from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
from dbgpt.core import Chunk, Embeddings
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.executor_utils import (
    blocking_func_to_async,
    blocking_func_to_async_no_executor,
)
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor

logger = logging.getLogger(__name__)

@@ -176,13 +173,30 @@ class IndexStoreBase(ABC):
        Return:
            List[str]: Chunk ids.
        """
        return await blocking_func_to_async(
            self._executor,
            self.load_document_with_limit,
            chunks,
            max_chunks_once_load,
            max_threads,
        chunk_groups = [
            chunks[i : i + max_chunks_once_load]
            for i in range(0, len(chunks), max_chunks_once_load)
        ]
        logger.info(
            f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups with "
            f"{max_threads} threads."
        )
        tasks = []
        for chunk_group in chunk_groups:
            tasks.append(self.aload_document(chunk_group))

        import asyncio

        results = await asyncio.gather(*tasks)

        ids = []
        loaded_cnt = 0
        for success_ids in results:
            ids.extend(success_ids)
            loaded_cnt += len(success_ids)
        logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")

        return ids

    def similar_search(
        self, text: str, topk: int, filters: Optional[MetadataFilters] = None
@@ -4,6 +4,10 @@ import logging
from abc import ABC, abstractmethod
from typing import List, Optional

from tenacity import retry, stop_after_attempt, wait_fixed

from dbgpt.core.interface.embeddings import Embeddings

logger = logging.getLogger(__name__)


@@ -22,6 +26,25 @@ class TransformerBase:
class EmbedderBase(TransformerBase, ABC):
    """Embedder base class."""

    def __init__(self, embedding_fn: Optional[Embeddings]):
        """Initialize the Embedder."""
        if not embedding_fn:
            raise ValueError("Embedding service is required.")
        self._embedding_fn = embedding_fn

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
    async def embed(self, text: str) -> List[float]:
        """Embed vector from text."""
        return await self._embedding_fn.aembed_query(text=text)

    @abstractmethod
    async def batch_embed(
        self,
        inputs: List,
        batch_size: int = 1,
    ) -> List:
        """Batch embed vectors from texts."""


class SummarizerBase(TransformerBase, ABC):
    """Summarizer base class."""
dbgpt/rag/transformer/graph_embedder.py (new file, 67 lines)
@@ -0,0 +1,67 @@
"""GraphEmbedder class."""

import asyncio
import logging
from typing import List

from dbgpt.rag.transformer.base import EmbedderBase
from dbgpt.storage.graph_store.graph import Graph, GraphElemType

logger = logging.getLogger(__name__)


class GraphEmbedder(EmbedderBase):
    """GraphEmbedder class."""

    async def batch_embed(
        self,
        inputs: List[Graph],
        batch_size: int = 1,
    ) -> List[Graph]:
        """Embed graphs in batches."""
        for graph in inputs:
            texts = []
            vectors = []

            # Get the text from the graph
            for vertex in graph.vertices():
                if vertex.get_prop("vertex_type") == GraphElemType.CHUNK.value:
                    texts.append(vertex.get_prop("content"))
                elif vertex.get_prop("vertex_type") == GraphElemType.ENTITY.value:
                    texts.append(vertex.vid)
                else:
                    texts.append(" ")

            n_texts = len(texts)

            # Batch embedding
            for batch_idx in range(0, n_texts, batch_size):
                start_idx = batch_idx
                end_idx = min(start_idx + batch_size, n_texts)
                batch_texts = texts[start_idx:end_idx]

                # Create embedding tasks
                embedding_tasks = [self.embed(text) for text in batch_texts]

                # Process the embeddings in parallel
                batch_results = await asyncio.gather(
                    *embedding_tasks, return_exceptions=True
                )

                # Place results in the correct positions
                for idx, vector in enumerate(batch_results):
                    if isinstance(vector, Exception):
                        raise RuntimeError(f"Failed to embed text {idx}")
                    vectors.append(vector)

            # Push the vectors back into the graph
            for vertex, vector in zip(graph.vertices(), vectors):
                vertex.set_prop("_embedding", vector)

        return inputs

    def truncate(self):
        """Do nothing by default."""

    def drop(self):
        """Do nothing by default."""
@@ -1,10 +0,0 @@
"""Text2Vector class."""
import logging

from dbgpt.rag.transformer.base import EmbedderBase

logger = logging.getLogger(__name__)


class Text2Vector(EmbedderBase):
    """Text2Vector class."""
dbgpt/rag/transformer/text_embedder.py (new file, 50 lines)
@@ -0,0 +1,50 @@
"""TextEmbedder class."""

import asyncio
import logging
from typing import List

from dbgpt.rag.transformer.base import EmbedderBase

logger = logging.getLogger(__name__)


class TextEmbedder(EmbedderBase):
    """TextEmbedder class."""

    async def batch_embed(
        self,
        inputs: List[str],
        batch_size: int = 1,
    ) -> List[List[float]]:
        """Embed texts in batches."""
        vectors: List[List[float]] = []
        n_texts = len(inputs)

        # Batch embedding
        for batch_idx in range(0, n_texts, batch_size):
            start_idx = batch_idx
            end_idx = min(start_idx + batch_size, n_texts)
            batch_texts = inputs[start_idx:end_idx]

            # Create embedding tasks
            embedding_tasks = [self.embed(text) for text in batch_texts]

            # Process the embeddings in parallel
            batch_results = await asyncio.gather(
                *embedding_tasks, return_exceptions=False
            )

            # Place results in the correct positions
            for idx, vector in enumerate(batch_results):
                if isinstance(vector, Exception):
                    raise RuntimeError(f"Failed to embed text {idx}")
                vectors.append(vector)

        return vectors

    def truncate(self):
        """Do nothing by default."""

    def drop(self):
        """Do nothing by default."""
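To make the new embedder API concrete, here is a minimal usage sketch (not part of the commit). It assumes DB-GPT is installed; `FakeEmbeddings` is a hypothetical stub standing in for a real embedding service and implements only `aembed_query`, the method `EmbedderBase.embed` actually calls. `GraphEmbedder` is used the same way, except its inputs are `Graph` objects whose vertices receive an `_embedding` property.

```python
import asyncio
from typing import List

from dbgpt.rag.transformer.text_embedder import TextEmbedder


class FakeEmbeddings:
    """Hypothetical stub for an embedding service; only aembed_query is exercised."""

    async def aembed_query(self, text: str) -> List[float]:
        # Deterministic toy vector; a real service returns model embeddings.
        return [float(len(text)), 1.0, 0.0]


async def main() -> None:
    embedder = TextEmbedder(embedding_fn=FakeEmbeddings())
    vectors = await embedder.batch_embed(
        inputs=["graph database", "similarity search"],
        batch_size=2,  # embed up to two texts concurrently
    )
    print(len(vectors), len(vectors[0]))  # -> 2 3


asyncio.run(main())
```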
@@ -27,6 +27,10 @@ class GraphStoreConfig(BaseModel):
        default=False,
        description="Enable graph community summary or not.",
    )
    enable_similarity_search: bool = Field(
        default=False,
        description="Enable similarity search or not.",
    )


class GraphStoreBase(ABC):

@@ -36,6 +40,8 @@ class GraphStoreBase(ABC):
        """Initialize graph store."""
        self._config = config
        self._conn = None
        self.enable_summary = config.enable_summary
        self.enable_similarity_search = config.enable_similarity_search

    @abstractmethod
    def get_config(self) -> GraphStoreConfig:
@@ -67,6 +67,14 @@ class TuGraphStoreConfig(GraphStoreConfig):
            "/dbgpt-tugraph-plugins/tree/master/cpp"
        ),
    )
    enable_summary: bool = Field(
        default=True,
        description="Enable graph community summary or not.",
    )
    enable_similarity_search: bool = Field(
        default=True,
        description="Enable the similarity search or not.",
    )


class TuGraphStore(GraphStoreBase):

@@ -79,9 +87,15 @@ class TuGraphStore(GraphStoreBase):
        self._port = int(os.getenv("TUGRAPH_PORT", config.port))
        self._username = os.getenv("TUGRAPH_USERNAME", config.username)
        self._password = os.getenv("TUGRAPH_PASSWORD", config.password)
        self._enable_summary = (
        self.enable_summary = (
            os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
            or config.enable_summary
            if "GRAPH_COMMUNITY_SUMMARY_ENABLED" in os.environ
            else config.enable_summary
        )
        self.enable_similarity_search = (
            os.environ["SIMILARITY_SEARCH_ENABLED"].lower() == "true"
            if "SIMILARITY_SEARCH_ENABLED" in os.environ
            else config.enable_similarity_search
        )
        self._plugin_names = (
            os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
@@ -3,7 +3,7 @@

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union
from typing import AsyncGenerator, Dict, Iterator, List, Optional, Union

from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.graph import (

@@ -176,18 +176,88 @@ class GraphStoreAdapter(ABC):
        """Check if the label exists in the graph."""

    @abstractmethod
    def explore(
    def explore_trigraph(
        self,
        subs: List[str],
        subs: Union[List[str], List[List[float]]],
        topk: Optional[int] = None,
        score_threshold: Optional[float] = None,
        direct: Direction = Direction.BOTH,
        depth: int = 3,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
        search_scope: Optional[
            Literal["knowledge_graph", "document_graph"]
        ] = "knowledge_graph",
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth."""
        """Explore the graph from given subjects up to a depth.

        Args:
            subs (Union[List[str], List[List[float]]]): The list of the subjects
                (keywords or embedding vectors).
            topk (Optional[int]): The number of the top similar entities.
            score_threshold (Optional[float]): The threshold of the similarity score.
            direct (Direction): The direction of the graph that will be explored.
            depth (int): The depth of the graph that will be explored.
            fan (Optional[int]): Not used.
            limit (Optional[int]): The limit number of the queried entities.

        Returns:
            MemoryGraph: The triplet graph that includes the entities and the relations.
        """

    @abstractmethod
    def explore_docgraph_with_entities(
        self,
        subs: List[str],
        topk: Optional[int] = None,
        score_threshold: Optional[float] = None,
        direct: Direction = Direction.BOTH,
        depth: int = 3,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth.

        Args:
            subs (List[str]): The list of the entities.
            topk (Optional[int]): The number of the top similar chunks.
            score_threshold (Optional[float]): The threshold of the similarity score.
            direct (Direction): The direction of the graph that will be explored.
            depth (int): The depth of the graph that will be explored.
            fan (Optional[int]): Not used.
            limit (Optional[int]): The limit number of the queried chunks.

        Returns:
            MemoryGraph: The document graph that includes the leaf chunks that connect
                to the entities, the chains from documents to the leaf chunks, and the
                chain from documents to chunks.
        """

    @abstractmethod
    def explore_docgraph_without_entities(
        self,
        subs: Union[List[str], List[List[float]]],
        topk: Optional[int] = None,
        score_threshold: Optional[float] = None,
        direct: Direction = Direction.BOTH,
        depth: int = 3,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth.

        Args:
            subs (Union[List[str], List[List[float]]]): The list of the subjects
                (keywords or embedding vectors).
            topk (Optional[int]): The number of the top similar chunks.
            score_threshold (Optional[float]): The threshold of the similarity score.
            direct (Direction): The direction of the graph that will be explored.
            depth (int): The depth of the graph that will be explored.
            fan (Optional[int]): Not used.
            limit (Optional[int]): The limit number of the queried chunks.

        Returns:
            MemoryGraph: The document graph that includes the chains from documents
                to chunks that contain the subs (keywords) or similar chunks
                (embedding vectors).
        """

    @abstractmethod
    def query(self, query: str, **kwargs) -> MemoryGraph:
@@ -56,7 +56,9 @@ class CommunityStore:
            return None

        graph = community.data.format()
        community.summary = await self._community_summarizer.summarize(graph=graph)
        community.summary = (
            await self._community_summarizer.summarize(graph=graph) or ""
        )
        logger.info(f"Summarize community {community_id}: {community.summary[:50]}...")
        return community
@@ -29,8 +29,7 @@ class MemGraphStoreAdapter(GraphStoreAdapter):

    def __init__(self, enable_summary: bool = False):
        """Initialize MemGraph Community Store Adapter."""
        self._graph_store = MemoryGraphStore(MemoryGraphStoreConfig())
        self._enable_summary = enable_summary
        self._graph_store: MemoryGraphStore = MemoryGraphStore(MemoryGraphStoreConfig())

        super().__init__(self._graph_store)

@@ -39,7 +38,7 @@ class MemGraphStoreAdapter(GraphStoreAdapter):

    async def discover_communities(self, **kwargs) -> List[str]:
        """Run community discovery with leiden."""
        []
        return []

    async def get_community(self, community_id: str) -> Community:
        """Get community."""

@@ -197,7 +196,7 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
            True if the label exists in the specified graph element type, otherwise
            False.
        """
        pass
        raise NotImplementedError("Memory graph store does not have label")

    def explore(
        self,

@@ -215,8 +214,8 @@ class MemGraphStoreAdapter(GraphStoreAdapter):

    def query(self, query: str, **kwargs) -> MemoryGraph:
        """Execute a query on graph."""
        pass
        raise NotImplementedError("Memory graph store does not support query")

    async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
        """Execute a stream query."""
        pass
        raise NotImplementedError("Memory graph store does not support stream query")
@@ -2,17 +2,9 @@

import json
import logging
from typing import (
    Any,
    AsyncGenerator,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)
from typing import Any, AsyncGenerator, Dict, Iterator, List, Optional, Tuple, Union

from packaging.version import Version

from dbgpt.storage.graph_store.graph import (
    Direction,
@@ -136,6 +128,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):

    def upsert_entities(self, entities: Iterator[Vertex]) -> None:
        """Upsert entities."""
        enable_similarity_search = self.graph_store.enable_similarity_search
        entity_list = [
            {
                "id": self._escape_quotes(entity.vid),

@@ -145,6 +138,11 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
                "_document_id": "0",
                "_chunk_id": "0",
                "_community_id": "0",
                **(
                    {"_embedding": entity.get_prop("_embedding")}
                    if enable_similarity_search
                    else {}
                ),
            }
            for entity in entities
        ]

@@ -153,6 +151,29 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
            f'"{GraphElemType.ENTITY.value}", '
            f"[{self._convert_dict_to_str(entity_list)}])"
        )

        # If similarity search is enabled, get ready to create the vector index
        if enable_similarity_search:
            # Check whether the vector index exists
            check_entity_vector_query = (
                "CALL db.showVertexVectorIndex() "
                "YIELD label_name, field_name "
                f"WHERE label_name = '{GraphElemType.ENTITY.value}' "
                "AND field_name = '_embedding' "
                "RETURN label_name"
            )
            # If it does not exist, create the vector index
            if self.query(check_entity_vector_query).vertex_count == 0:
                # Get the dimension
                dimension = len(entity_list[0].get("_embedding", []))
                # Then create the index
                create_vector_index_query = (
                    "CALL db.addVertexVectorIndex("
                    f'"{GraphElemType.ENTITY.value}", "_embedding", '
                    f"{{dimension: {dimension}}})"
                )
                self.graph_store.conn.run(query=create_vector_index_query)

        self.graph_store.conn.run(query=entity_query)

    def upsert_edge(
@@ -178,26 +199,62 @@ class TuGraphStoreAdapter(GraphStoreAdapter):

    def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None:
        """Upsert chunks."""
        chunk_list = [
        enable_similarity_search = self.graph_store.enable_similarity_search
        chunk_list: List[Dict[str, Union[str, List[float]]]] = [
            {
                "id": self._escape_quotes(chunk.chunk_id),
                "name": self._escape_quotes(chunk.chunk_name),
                "content": self._escape_quotes(chunk.content),
                **(
                    {"_embedding": chunk.embedding}
                    if enable_similarity_search and chunk.embedding
                    else {}
                ),
            }
            if isinstance(chunk, ParagraphChunk)
            else {
                "id": self._escape_quotes(chunk.vid),
                "name": self._escape_quotes(chunk.name),
                "content": self._escape_quotes(chunk.get_prop("content")),
                **(
                    {"_embedding": chunk.get_prop("_embedding")}
                    if enable_similarity_search
                    else {}
                ),
            }
            for chunk in chunks
        ]
        if len(chunk_list) == 0:
            return

        chunk_query = (
            f"CALL db.upsertVertex("
            f'"{GraphElemType.CHUNK.value}", '
            f"[{self._convert_dict_to_str(chunk_list)}])"
        )

        # If similarity search is enabled, get ready to create the vector index
        if enable_similarity_search:
            # Check whether the vector index exists
            check_chunk_vector_query = (
                "CALL db.showVertexVectorIndex() "
                "YIELD label_name, field_name "
                f"WHERE label_name = '{GraphElemType.CHUNK.value}' "
                "AND field_name = '_embedding' "
                "RETURN label_name"
            )
            # If it does not exist, create the vector index
            if self.query(check_chunk_vector_query).vertex_count == 0:
                # Get the dimension
                dimension = len(chunk_list[0].get("_embedding", []))
                # Then create the index
                create_vector_index_query = (
                    "CALL db.addVertexVectorIndex("
                    f'"{GraphElemType.CHUNK.value}", "_embedding", '
                    f"{{dimension: {dimension}}})"
                )
                self.graph_store.conn.run(query=create_vector_index_query)

        self.graph_store.conn.run(query=chunk_query)

    def upsert_documents(
@@ -355,6 +412,18 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
        if not self.graph_store.conn.create_graph(graph_name=graph_name):
            return

        # Compatibility check
        enable_similarity_search = self.graph_store.enable_similarity_search

        dbms_system_info = self.graph_store.conn.get_system_info()
        lgraph_version = dbms_system_info["lgraph_version"]
        similarity_search_compatible = Version(lgraph_version) >= Version("4.5.1")

        if enable_similarity_search and not similarity_search_compatible:
            raise Exception(
                "TuGraph 4.5.0 and below does not support similarity search."
            )

        # Create the graph schema
        def _format_graph_property_schema(
            name: str,

@@ -405,6 +474,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
            _format_graph_property_schema("_community_id", "STRING", True, True),
            _format_graph_property_schema("content", "STRING", True, True),
        ]
        if enable_similarity_search:
            chunk_proerties.append(
                _format_graph_property_schema("_embedding", "FLOAT_VECTOR", True, False)
            )
        self.create_graph_label(
            graph_elem_type=GraphElemType.CHUNK, graph_properties=chunk_proerties
        )
@@ -416,6 +489,12 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
            _format_graph_property_schema("_community_id", "STRING", True, True),
            _format_graph_property_schema("description", "STRING", True, True),
        ]
        if enable_similarity_search:
            vertex_proerties.append(
                _format_graph_property_schema(
                    "_embedding", "FLOAT_VECTOR", True, False
                ),
            )
        self.create_graph_label(
            graph_elem_type=GraphElemType.ENTITY, graph_properties=vertex_proerties
        )

@@ -451,7 +530,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
            graph_elem_type=GraphElemType.NEXT, graph_properties=next_proerties
        )

        if self.graph_store._enable_summary:
        if self.graph_store.enable_summary:
            self.graph_store._upload_plugin()

    def create_graph_label(
@@ -465,7 +544,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
        (vertices) and edges in the graph.
        """
        if graph_elem_type.is_vertex():  # vertex
            data = json.dumps(
            vertex_meta = json.dumps(
                {
                    "label": graph_elem_type.value,
                    "type": "VERTEX",

@@ -473,7 +552,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
                    "properties": graph_properties,
                }
            )
            gql = f"""CALL db.createVertexLabelByJson('{data}')"""
            gql = f"""CALL db.createVertexLabelByJson('{vertex_meta}')"""
        else:  # edge

            def edge_direction(graph_elem_type: GraphElemType) -> List[List[str]]:

@@ -498,7 +577,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
            else:
                raise ValueError("Invalid graph element type.")

            data = json.dumps(
            edge_meta = json.dumps(
                {
                    "label": graph_elem_type.value,
                    "type": "EDGE",

@@ -506,7 +585,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
                    "properties": graph_properties,
                }
            )
            gql = f"""CALL db.createEdgeLabelByJson('{data}')"""
            gql = f"""CALL db.createEdgeLabelByJson('{edge_meta}')"""

        self.graph_store.conn.run(gql)
@@ -529,18 +608,31 @@ class TuGraphStoreAdapter(GraphStoreAdapter):

        return graph_elem_type.value in tables

    def explore(
    def explore_trigraph(
        self,
        subs: List[str],
        subs: Union[List[str], List[List[float]]],
        topk: Optional[int] = None,
        score_threshold: Optional[float] = None,
        direct: Direction = Direction.BOTH,
        depth: int = 3,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
        search_scope: Optional[
            Literal["knowledge_graph", "document_graph"]
        ] = "knowledge_graph",
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth."""
        """Explore the graph from given subjects up to a depth.

        Args:
            subs (Union[List[str], List[List[float]]]): The list of the subjects
                (keywords or embedding vectors).
            topk (Optional[int]): The number of the top similar entities.
            score_threshold (Optional[float]): The threshold of the similarity score.
            direct (Direction): The direction of the graph that will be explored.
            depth (int): The depth of the graph that will be explored.
            fan (Optional[int]): Not used.
            limit (Optional[int]): The limit number of the queried entities.

        Returns:
            MemoryGraph: The triplet graph that includes the entities and the relations.
        """
        if not subs:
            return MemoryGraph()
@@ -553,110 +645,247 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
        else:
            limit_string = f"LIMIT {limit}"

        if search_scope == "knowledge_graph":
            if direct.name == "OUT":
                rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]->"
            elif direct.name == "IN":
                rel = f"<-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
            else:
                rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
            query = (
                f"MATCH p=(n:{GraphElemType.ENTITY.value})"
                f"{rel}(m:{GraphElemType.ENTITY.value}) "
                f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
                f"RETURN p {limit_string}"
            )
            return self.query(query=query, white_list=["description"])
        if direct.name == "OUT":
            rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]->"
        elif direct.name == "IN":
            rel = f"<-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
        else:
            rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]-"

        graph = MemoryGraph()
        enable_similarity_search = self.graph_store.enable_similarity_search

        if enable_similarity_search:
            # If similarity search is enabled, use knn search to get the ids
            similar_entities = []
            # Take each vector from the subjects,
            # then run a knn search for it with TuGraph
            for vector in subs:
                similarity_retrieval_query = (
                    "CALL db.vertexVectorKnnSearch("
                    f"'{GraphElemType.ENTITY.value}','_embedding', {vector}, "
                    f"{{top_k:{topk}}}) YIELD node "
                    f"WHERE node.distance < {score_threshold} "
                    "RETURN node.id AS id;"
                )
                # Merge the results of the knn searches
                similar_entities.extend(
                    self.graph_store.conn.run(query=similarity_retrieval_query)
                )
            # Get the ids from the retrieved entities
            ids = [(record["id"]) for record in similar_entities]
            conditional_statement = f"WHERE n.id IN {ids} "
        else:
            conditional_statement = (
                f"WHERE n.id IN {[self._escape_quotes(str(sub)) for sub in subs]} "
            )

        # Multi-hop search
        query = (
            f"MATCH p=(n:{GraphElemType.ENTITY.value})"
            f"{rel}(m:{GraphElemType.ENTITY.value}) "
            f"{conditional_statement}"
            f"RETURN p {limit_string}"
        )

        return self.query(query=query, white_list=["description"])
    def explore_docgraph_with_entities(
        self,
        subs: List[str],
        topk: Optional[int] = None,
        score_threshold: Optional[float] = None,
        direct: Direction = Direction.BOTH,
        depth: int = 3,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth.

        Args:
            subs (List[str]): The list of the entities.
            topk (Optional[int]): The number of the top similar chunks.
            score_threshold (Optional[float]): The threshold of the similarity score.
            direct (Direction): The direction of the graph that will be explored.
            depth (int): The depth of the graph that will be explored.
            fan (Optional[int]): Not used.
            limit (Optional[int]): The limit number of the queried chunks.

        Returns:
            MemoryGraph: The document graph that includes the leaf chunks that
                connect to the entities, the chains from documents to the leaf chunks,
                and the chain from documents to chunks.
        """
        if len(subs) == 0:
            return MemoryGraph()

        if depth <= 0:
            depth = 3
        depth_string = f"1..{depth}"

        if limit is None:
            limit_string = ""
        else:
            limit_string = f"LIMIT {limit}"

        # If the entities exist in the graph, return the graph that
        # includes the leaf chunks that connect to the entities, the chains from
        # documents to the leaf chunks, and the chain from documents to chunks;
        # document -> chunk -> chunk -> ... -> leaf chunk -> (entity)
        #
        # If not, return the graph that includes the chains from documents to chunks
        # that contain the subs (keywords).
        # document -> chunk -> chunk -> ... -> leaf chunk (that contains the subs)
        #
        # Only the leaf chunks carry the content; the other chunks do not
        # carry any properties except the id and name.

        graph = MemoryGraph()

        # Query the leaf chunks in the chain from documents to chunks
        conditional_statement = (
            f"WHERE m.name IN {[self._escape_quotes(sub) for sub in subs]} "
        )
        leaf_chunk_query = (
            f"MATCH p=(n:{GraphElemType.CHUNK.value})-"
            f"[r:{GraphElemType.INCLUDE.value}]->"
            f"(m:{GraphElemType.ENTITY.value})"
            f"{conditional_statement} "
            "RETURN n"
        )
        graph_of_leaf_chunks = self.query(
            query=leaf_chunk_query, white_list=["content"]
        )

        # Query the chain from documents to chunks,
        # document -> chunk -> ... -> leaf_chunks
        chunk_names = [
            self._escape_quotes(vertex.name)
            for vertex in graph_of_leaf_chunks.vertices()
        ]
        chain_query = (
            f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
            f"[:{GraphElemType.INCLUDE.value}*{depth_string}]->"
            f"(m:{GraphElemType.CHUNK.value})"
            f"WHERE m.name IN {chunk_names} "
            "RETURN p"
        )
        # Filter all the properties by white_list
        graph.upsert_graph(self.query(query=chain_query, white_list=[""]))

        # The number of leaf chunks compared to the `limit`
        if not limit or len(chunk_names) <= limit:
            graph.upsert_graph(graph_of_leaf_chunks)
        else:
            limited_leaf_chunk_query = leaf_chunk_query + f" {limit_string}"
            graph.upsert_graph(
                self.query(query=limited_leaf_chunk_query, white_list=["content"])
            )

        return graph
    def explore_docgraph_without_entities(
        self,
        subs: Union[List[str], List[List[float]]],
        topk: Optional[int] = None,
        score_threshold: Optional[float] = None,
        direct: Direction = Direction.BOTH,
        depth: int = 3,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth.

        Args:
            subs (Union[List[str], List[List[float]]]): The list of the subjects
                (keywords or embedding vectors).
            topk (Optional[int]): The number of the top similar chunks.
            score_threshold (Optional[float]): The threshold of the similarity score.
            direct (Direction): The direction of the graph that will be explored.
            depth (int): The depth of the graph that will be explored.
            fan (Optional[int]): Not used.
            limit (Optional[int]): The limit number of the queried chunks.

        Returns:
            MemoryGraph: The document graph that includes the chains from documents
                to chunks that contain the subs (keywords) or similar chunks
                (embedding vectors).
        """
        if len(subs) == 0:
            return MemoryGraph()

        if depth <= 0:
            depth = 3
        depth_string = f"1..{depth}"

        if limit is None:
            limit_string = ""
        else:
            limit_string = f"LIMIT {limit}"

        enable_similarity_search = self.graph_store.enable_similarity_search

        # Return the graph that includes the chains from documents to the chunks
        # that contain the subs (keywords) or are similar to them
        # (embedding vectors):
        # document -> chunk -> chunk -> ... -> leaf chunk (that contains the subs)
        #
        # Only the leaf chunks carry the content; the other chunks do not
        # carry any properties except the id and name.

        graph = MemoryGraph()

        if enable_similarity_search:
            similar_chunks = []
            for vector in subs:
                similarity_retrieval_query = (
                    "CALL db.vertexVectorKnnSearch("
                    f"'{GraphElemType.CHUNK.value}','_embedding', {vector}, "
                    f"{{top_k:{topk}}}) YIELD node "
                    f"WHERE node.distance < {score_threshold} "
                    "RETURN node.name AS name"
                )
                similar_chunks.extend(
                    self.graph_store.conn.run(query=similarity_retrieval_query)
                )
            names = [(record["name"]) for record in similar_chunks]
            _subs_condition = " OR ".join(
                [f"m.content CONTAINS '{name}'" for name in names]
            )
        else:
            _subs_condition = " OR ".join(
                [
                    f"m.content CONTAINS '{self._escape_quotes(str(sub))}'"
                    for sub in subs
                ]
            )

        # Query the chain from documents to chunks,
        # document -> chunk -> chunk -> chunk -> ... -> chunk
        chain_query = (
            f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
            f"[r:{GraphElemType.INCLUDE.value}*{depth_string}]->"
            f"(m:{GraphElemType.CHUNK.value})"
            f"WHERE {_subs_condition}"
            "RETURN p"
        )
        # Filter all the properties by white_list
        graph.upsert_graph(self.query(query=chain_query, white_list=[""]))

        # Query the leaf chunks in the chain from documents to chunks
        leaf_chunk_query = (
            f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
            f"[r:{GraphElemType.INCLUDE.value}*{depth_string}]->"
            f"(m:{GraphElemType.CHUNK.value})"
            f"WHERE {_subs_condition}"
            f"RETURN m {limit_string}"
        )
        graph.upsert_graph(
            self.query(query=leaf_chunk_query, white_list=["content"])
        )

        return graph

    def query(self, query: str, **kwargs) -> MemoryGraph:
        """Execute a query on graph.
@@ -3,13 +3,15 @@

import logging
import os
import uuid
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import Chunk, LLMClient
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
from dbgpt.rag.transformer.graph_embedder import GraphEmbedder
from dbgpt.rag.transformer.graph_extractor import GraphExtractor
from dbgpt.rag.transformer.text_embedder import TextEmbedder
from dbgpt.storage.knowledge_graph.base import ParagraphChunk
from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore
from dbgpt.storage.knowledge_graph.knowledge_graph import (

@@ -179,7 +181,6 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
        default=True,
        description="Enable the graph search for documents and chunks",
    )

    knowledge_graph_chunk_search_top_size: int = Field(
        default=5,
        description="Top size of knowledge graph chunk search",

@@ -192,6 +193,18 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
        default=20,
        description="Batch size of parallel community building process",
    )
    knowledge_graph_embedding_batch_size: int = Field(
        default=20,
        description="Batch size of triplets embedding from the text",
    )
    similarity_search_topk: int = Field(
        default=5,
        description="Topk of similarity search",
    )
    similarity_search_score_threshold: float = Field(
        default=0.7,
        description="Recall score of similarity search",
    )
@register_resource(

@@ -263,12 +276,30 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
                config.knowledge_graph_extraction_batch_size,
            )
        )
        self._triplet_embedding_batch_size = int(
            os.getenv(
                "KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE",
                config.knowledge_graph_embedding_batch_size,
            )
        )
        self._community_summary_batch_size = int(
            os.getenv(
                "COMMUNITY_SUMMARY_BATCH_SIZE",
                config.community_summary_batch_size,
            )
        )
        self._similarity_search_topk = int(
            os.getenv(
                "KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE",
                config.similarity_search_topk,
            )
        )
        self._similarity_search_score_threshold = float(
            os.getenv(
                "KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_RECALL_SCORE",
                config.similarity_search_score_threshold,
            )
        )

        def extractor_configure(name: str, cfg: VectorStoreConfig):
            cfg.name = name

@@ -290,6 +321,9 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
            ),
        )

        self._graph_embedder = GraphEmbedder(self._config.embedding_fn)
        self._text_embedder = TextEmbedder(self._config.embedding_fn)

        def community_store_configure(name: str, cfg: VectorStoreConfig):
            cfg.name = name
            cfg.embedding_fn = config.embedding_fn
@@ -339,6 +373,18 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
        ]
        documment_chunk, paragraph_chunks = self._load_chunks(_chunks)

        if self._graph_store.enable_similarity_search:
            # Add embeddings from the chunk content
            texts: List[str] = [chunk.content for chunk in paragraph_chunks]

            embeddings = await self._text_embedder.batch_embed(
                inputs=texts,
                batch_size=self._triplet_embedding_batch_size,
            )

            for idx, chunk in enumerate(paragraph_chunks):
                chunk.embedding = embeddings[idx]

        # upsert the document and chunks vertices
        self._graph_store_apdater.upsert_documents(iter([documment_chunk]))
        self._graph_store_apdater.upsert_chunks(iter(paragraph_chunks))

@@ -376,6 +422,15 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
        if not graphs_list:
            raise ValueError("No graphs extracted from the chunks")

        # If similarity search is enabled, add the embeddings to the graphs
        if self._graph_store.enable_similarity_search:
            for idx, graphs in enumerate(graphs_list):
                embeded_graphs = await self._graph_embedder.batch_embed(
                    inputs=graphs,
                    batch_size=self._triplet_embedding_batch_size,
                )
                graphs_list[idx] = embeded_graphs

        # Upsert the graphs into the graph store
        for idx, graphs in enumerate(graphs_list):
            for graph in graphs:

@@ -465,7 +520,8 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
        ]
        context = "\n".join(summaries) if summaries else ""

        keywords: List[str] = await self._keyword_extractor.extract(text)
        enable_similarity_search = self._graph_store.enable_similarity_search

        subgraph = None
        subgraph_for_doc = None
@@ -473,28 +529,68 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
        triplet_graph_enabled = self._triplet_graph_enabled
        document_graph_enabled = self._document_graph_enabled

        # Use subs to transfer keywords or embeddings
        # Use subs to transfer keywords
        keywords: List[str] = await self._keyword_extractor.extract(text)

        # If similarity search is enabled, use subs to transfer embeddings
        subs: Union[List[str], List[List[float]]]
        if enable_similarity_search:
            # Embed the question
            vector = await self._text_embedder.embed(text)
            # Embed the keywords
            vectors = await self._text_embedder.batch_embed(
                keywords, batch_size=self._triplet_embedding_batch_size
            )
            # Use the embeddings of the keywords and the question
            vectors.append(vector)
            subs = vectors
        else:
            subs = keywords

        # If the triplet graph is enabled, use subs to search entities
        # subs -> entities
        if triplet_graph_enabled:
            subgraph = self._graph_store_apdater.explore(
                subs=keywords, limit=topk, search_scope="knowledge_graph"
            subgraph = self._graph_store_apdater.explore_trigraph(
                subs=subs,
                limit=topk,
                topk=self._similarity_search_topk,
                score_threshold=self._similarity_search_score_threshold,
            )

        if document_graph_enabled:
        # If the document graph is enabled
        if document_graph_enabled:
            # If the triplet graph is not enabled or the subgraph is empty,
            # use subs to search chunks
            # subs -> chunks -> doc
            if subgraph is None or subgraph.vertex_count == 0:
                subgraph_for_doc = (
                    self._graph_store_apdater.explore_docgraph_without_entities(
                        subs=subs,
                        topk=self._similarity_search_topk,
                        score_threshold=self._similarity_search_score_threshold,
                        limit=self._knowledge_graph_chunk_search_top_size,
                    )
                )
            else:
                # If there are searched entities,
                # append the vids of the entities;
                # the VID is the KEYWORD stored in the entity
                keywords_for_document_graph = keywords
                for vertex in subgraph.vertices():
                    keywords_for_document_graph.append(vertex.name)

                subgraph_for_doc = self._graph_store_apdater.explore(
                    subs=keywords_for_document_graph,
                    limit=self._knowledge_graph_chunk_search_top_size,
                    search_scope="document_graph",
                )
        else:
            if document_graph_enabled:
                subgraph_for_doc = self._graph_store_apdater.explore(
                    subs=keywords,
                    limit=self._knowledge_graph_chunk_search_top_size,
                    search_scope="document_graph",
                # Use the vids to search chunks and the doc
                # entities -> chunks -> doc
                subgraph_for_doc = (
                    self._graph_store_apdater.explore_docgraph_with_entities(
                        subs=keywords_for_document_graph,
                        topk=self._similarity_search_topk,
                        score_threshold=self._similarity_search_score_threshold,
                        limit=self._knowledge_graph_chunk_search_top_size,
                    )
                )

        knowledge_graph_str = subgraph.format() if subgraph else ""
        knowledge_graph_for_doc_str = (
            subgraph_for_doc.format() if subgraph_for_doc else ""
@@ -522,6 +618,10 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
        self._keyword_extractor.truncate()
        logger.info("Truncate triplet extractor")
        self._graph_extractor.truncate()
        logger.info("Truncate graph embedder")
        self._graph_embedder.truncate()
        logger.info("Truncate text embedder")
        self._text_embedder.truncate()
        return [self._config.name]

    def delete_vector_name(self, index_name: str):

@@ -535,6 +635,12 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
        logger.info("Drop triplet extractor")
        self._graph_extractor.drop()

        logger.info("Drop graph embedder")
        self._graph_embedder.drop()

        logger.info("Drop text embedder")
        self._text_embedder.drop()


HYBRID_SEARCH_PT = """
=====
@@ -221,7 +221,9 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):

        # extract keywords and explore graph store
        keywords = await self._keyword_extractor.extract(text)
        subgraph = self._graph_store_apdater.explore(keywords, limit=topk).format()
        subgraph = self._graph_store_apdater.explore_trigraph(
            keywords, limit=topk
        ).format()

        logger.info(f"Search subgraph from {len(keywords)} keywords")

@@ -232,7 +234,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
            "The following entities and relationships provided after "
            "[Subgraph] are retrieved from the knowledge graph "
            "based on the keywords:\n"
            f"\"{','.join(keywords)}\".\n"
            f'"{",".join(keywords)}".\n'
            "---------------------\n"
            "The following examples after [Entities] and [Relationships] that "
            "can help you understand the data format of the knowledge graph, "
@@ -17,7 +17,7 @@ pip install "dbgpt[graph_rag]>=0.6.1"

To store the knowledge in a graph, we need a graph database. [TuGraph](https://github.com/TuGraph-family/tugraph-db) is the first graph database supported by DB-GPT.

Visit the TuGraph GitHub repository to view the [Quick Start](https://tugraph-db.readthedocs.io/zh-cn/latest/3.quick-start/1.preparation.html#id5) document, then follow the instructions to pull the TuGraph database docker image (latest / version >= 4.5.0) and launch it.
Visit the TuGraph GitHub repository to view the [Quick Start](https://tugraph-db.readthedocs.io/zh-cn/latest/3.quick-start/1.preparation.html#id5) document, then follow the instructions to pull the TuGraph database docker image (latest / version >= 4.5.1) and launch it.

```
docker pull tugraph/tugraph-runtime-centos7:4.5.1
@@ -368,7 +368,7 @@ Comparison of the DB-GPT and TuGraph communities
Overall, the DB-GPT and TuGraph communities each have their own strengths in community contribution, ecosystem, and developer engagement. The DB-GPT community leans toward diverse AI applications and cross-organization collaboration, while the TuGraph community focuses on efficient management and analysis of graph data. What they share is an emphasis on open source and community collaboration, which drives technical progress and application development in their respective fields.
```

### Latest Updates
### Retrieval Of Document Structure

In version 0.6.1 of DB-GPT, we have added a new feature:
- Retrieval of triplets with the **retrieval of document structure**

@@ -397,4 +397,79 @@ We decompose standard format files (currently best support for Markdown files)

What's next?

We aim to construct a more complex Graph that covers more comprehensive information to support more sophisticated retrieval algorithms in our GraphRAG.
### Similarity Search in GraphRAG

In the latest version of DB-GPT, we have implemented a new feature:

- **Similarity search** for GraphRAG retrieval

#### How to use?

Use TuGraph 4.5.1 or above.

Set the variables below in the `.env` file to enable similarity search in DB-GPT.

```
SIMILARITY_SEARCH_ENABLED=True # enable the similarity search for entities and chunks
KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE=20 # the batch size of embedding from the text
KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE=5 # set the topk of the vector similarity search
KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_RECALL_SCORE=0.3 # set the recall score of the vector similarity search
```

Additionally, you need to choose an embedding model in the `.env` file.

```
## OpenAI embedding model, see dbgpt/model/parameter.py
# EMBEDDING_MODEL=proxy_openai
# proxy_openai_proxy_server_url=https://api.openai.com/v1
# proxy_openai_proxy_api_key={your-openai-sk}
# proxy_openai_proxy_backend=text-embedding-ada-002


## Qwen embedding model, see dbgpt/model/parameter.py
# EMBEDDING_MODEL=proxy_tongyi
# proxy_tongyi_proxy_backend=text-embedding-v1
# proxy_tongyi_proxy_api_key={your-api-key}

## Qianfan embedding model, see dbgpt/model/parameter.py
#EMBEDDING_MODEL=proxy_qianfan
#proxy_qianfan_proxy_backend=bge-large-zh
#proxy_qianfan_proxy_api_key={your-api-key}
#proxy_qianfan_proxy_api_secret={your-secret-key}
```
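If you wire stores up in code rather than through `.env`, the same switches are exposed on the graph store config. A minimal sketch based on the fields this commit adds (the import path may differ across versions; note that the environment variables `GRAPH_COMMUNITY_SUMMARY_ENABLED` and `SIMILARITY_SEARCH_ENABLED` take precedence over these fields when set):

```python
from dbgpt.storage.graph_store.tugraph_store import TuGraphStoreConfig

# Both switches default to True for TuGraph in this commit;
# they are shown explicitly here for clarity.
config = TuGraphStoreConfig(
    enable_summary=True,             # graph community summary
    enable_similarity_search=True,   # vector index + knn search on _embedding
)
```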
#### Why use it?

TuGraph now offers comprehensive vector capabilities, including vector storage, indexing, and similarity search. These features enable GraphRAG to achieve better retrieval performance than traditional keyword-based approaches.

To leverage these capabilities, we've introduced an `_embedding` field in both entity and chunk objects to store embedding data, enabling similarity search to identify the most relevant results for a given query.
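Under the hood, the store adapter issues TuGraph's vector procedures against that field. A simplified sketch of the knn query it generates (the vector literal and thresholds below are placeholders, and `'entity'` stands for `GraphElemType.ENTITY.value`):

```python
# Placeholder values; real vectors come from the configured embedding model.
vector = [0.12, -0.03, 0.88]  # query embedding
topk = 5                      # KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE
score_threshold = 0.7         # KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_RECALL_SCORE

knn_query = (
    "CALL db.vertexVectorKnnSearch("
    f"'entity', '_embedding', {vector}, {{top_k:{topk}}}) YIELD node "
    f"WHERE node.distance < {score_threshold} "
    "RETURN node.id AS id;"
)
```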
#### Comparison of Similarity Search Results

Given identical documents and questions in the same environment, the results of the keyword mode are as follows:

<p align="left">
  <img src={'/img/chat_knowledge/graph_rag/comparison_result_for_keywords.png'} width="1000px"/>
</p>

The results of the similarity search mode are as follows:

<p align="left">
  <img src={'/img/chat_knowledge/graph_rag/comparison_result_for_similarity_search.png'} width="1000px"/>
</p>

Compared to keyword search, similarity search can cover more comprehensive information. For instance, for the term 清北大学 it is hard to extract useful keywords in keyword mode. The similarity search mode, however, can identify similar words, enabling it to retrieve information related to Tsinghua University and include it in the search results.

This implies that in scenarios where queries are imprecise, the similarity search approach can retrieve more pertinent information than keyword-based search.

Furthermore, as shown in the following figure, compared to plain RAG, GraphRAG with similarity search can obtain more relevant information, ensuring richer answers.

<p align="left">
  <img src={'/img/chat_knowledge/graph_rag/comparison_with_rag.png'} width="1000px"/>
</p>

In conclusion, enabling similarity search in GraphRAG significantly expands the scope and relevance of its responses.
BIN docs/static/img/chat_knowledge/graph_rag/comparison_result_for_keywords.png (vendored, new file; binary not shown; size 72 KiB)
BIN docs/static/img/chat_knowledge/graph_rag/comparison_result_for_similarity_search.png (vendored, new file; binary not shown; size 75 KiB)
BIN docs/static/img/chat_knowledge/graph_rag/comparison_with_rag.png (vendored, new file; binary not shown; size 176 KiB)