GraphRAG: add new feature and bugfix (#2373)

Co-authored-by: 秉翟 <lyusonglin.lsl@antgroup.com>
This commit is contained in:
SonglinLyu
2025-02-27 21:56:19 +08:00
committed by GitHub
parent 1e0674c54d
commit 55daa31dd9
9 changed files with 115 additions and 31 deletions

View File

@@ -50,26 +50,6 @@ class SystemParameters:
)
@dataclass
class StorageVectorConfig(BaseParameters):
    # Single setting: the name of the vector store backend ("Chroma" unless
    # overridden). The help text is wrapped in `_()` as elsewhere in the file.
    type: str = field(
        default="Chroma", metadata={"help": _("default vector type")}
    )
@dataclass
class StorageGraphConfig(BaseParameters):
    # Single setting: the name of the graph store backend ("TuGraph" unless
    # overridden). The help text is wrapped in `_()` as elsewhere in the file.
    type: str = field(
        default="TuGraph", metadata={"help": _("default graph type")}
    )
@dataclass
class StorageConfig(BaseParameters):
vector: VectorStoreConfig = field(

View File

@@ -5,7 +5,10 @@ import os
from typing import List, Tuple, Union
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
from dbgpt.rag.transformer.simple_intent_translator import SimpleIntentTranslator
from dbgpt.storage.graph_store.graph import Graph, MemoryGraph
from dbgpt_ext.rag.transformer.local_text2gql import LocalText2GQL
from dbgpt_ext.rag.transformer.text2gql import Text2GQL
from ...transformer.text_embedder import TextEmbedder
from .base import GraphRetrieverBase
@@ -81,10 +84,25 @@ class GraphRetriever(GraphRetrieverBase):
if "TEXT_SEARCH_ENABLED" in os.environ
else config.enable_text_search
)
text2gql_model_enabled = (
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
if "TEXT2GQL_MODEL_ENABLED" in os.environ
else config.text2gql_model_enabled
)
text2gql_model_name = os.getenv(
"TEXT2GQL_MODEL_NAME",
config.text2gql_model_name,
)
self._keyword_extractor = KeywordExtractor(llm_client, model_name)
self._text_embedder = TextEmbedder(config.embedding_fn)
intent_interpreter = SimpleIntentTranslator(llm_client, model_name)
if text2gql_model_enabled:
text2gql = LocalText2GQL(text2gql_model_name)
else:
text2gql = Text2GQL(llm_client, model_name)
self._keyword_based_graph_retriever = KeywordBasedGraphRetriever(
graph_store_adapter, triplet_topk
)
@@ -95,7 +113,10 @@ class GraphRetriever(GraphRetrieverBase):
similarity_search_score_threshold,
)
self._text_based_graph_retriever = TextBasedGraphRetriever(
graph_store_adapter, triplet_topk, llm_client, model_name
graph_store_adapter,
triplet_topk,
intent_interpreter,
text2gql,
)
self._document_graph_retriever = DocumentGraphRetriever(
graph_store_adapter,

View File

@@ -4,10 +4,8 @@ import json
import logging
from typing import Dict, List, Tuple, Union
from dbgpt.rag.transformer.simple_intent_translator import SimpleIntentTranslator
from dbgpt.storage.graph_store.graph import Graph, MemoryGraph
from dbgpt_ext.rag.retriever.graph_retriever.base import GraphRetrieverBase
from dbgpt_ext.rag.transformer.text2gql import Text2GQL
logger = logging.getLogger(__name__)
@@ -15,12 +13,18 @@ logger = logging.getLogger(__name__)
class TextBasedGraphRetriever(GraphRetrieverBase):
"""Text Based Graph Retriever class."""
def __init__(self, graph_store_adapter, triplet_topk, llm_client, model_name):
def __init__(
self,
graph_store_adapter,
triplet_topk,
intent_interpreter,
text2gql,
):
"""Initialize Text Based Graph Retriever."""
self._graph_store_adapter = graph_store_adapter
self._triplet_topk = triplet_topk
self._intent_interpreter = SimpleIntentTranslator(llm_client, model_name)
self._text2gql = Text2GQL(llm_client, model_name)
self._intent_interpreter = intent_interpreter
self._text2gql = text2gql
async def retrieve(self, text: str) -> Tuple[Graph, str]:
"""Retrieve from triplets graph with text2gql."""

View File

@@ -0,0 +1,70 @@
"""LocalText2GQL class."""
import json
import logging
import re
from typing import Dict, List, Union
from dbgpt.core import BaseMessage, HumanPromptTemplate
from dbgpt.model.proxy.llms.ollama import OllamaLLMClient
from dbgpt.rag.transformer.llm_translator import LLMTranslator
# Prompt template handed to the LLMTranslator constructor by
# LocalText2GQL.__init__. It exposes two placeholders, ``{schema}`` and
# ``{question}``; LocalText2GQL._format_messages supplies values for both
# (presumably via the template stored as ``self._prompt_template`` -- confirm
# against LLMTranslator).
LOCAL_TEXT_TO_GQL_PT = """
A question written in graph query language style is provided below. Given the question, translate the question into a cypher query that can be executed on the given knowledge graph. Make sure the syntax of the translated cypher query is correct.
To help query generation, the schema of the knowledge graph is:
{schema}
---------------------
Question: {question}
Query:
""" # noqa: E501
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class LocalText2GQL(LLMTranslator):
    """Translate a rewritten question into a Cypher query.

    The translator is built on ``LLMTranslator`` with an ``OllamaLLMClient``
    and the ``LOCAL_TEXT_TO_GQL_PT`` prompt template.
    """

    def __init__(self, model_name: str):
        """Initialize the LocalText2GQL.

        Args:
            model_name: Name of the model, forwarded to ``LLMTranslator``
                together with a fresh ``OllamaLLMClient``.
        """
        super().__init__(OllamaLLMClient(), model_name, LOCAL_TEXT_TO_GQL_PT)

    def _format_messages(
        self, text: str, history: Union[str, None] = None
    ) -> List[BaseMessage]:
        """Build the prompt messages for one translation request.

        Args:
            text: JSON-encoded intention. The keys ``"rewritten_question"``
                and ``"schema"`` are read; each defaults to ``""`` when
                absent.
            history: Optional history string. It is only passed to the
                template when it is not ``None``.

        Returns:
            The messages produced by the prompt template.

        Raises:
            json.JSONDecodeError: If ``text`` is not valid JSON.
        """
        # Translate intention to gql with a single prompt only.
        intention: Dict[str, Union[str, List[str]]] = json.loads(text)

        template_kwargs = {
            "schema": intention.get("schema", ""),
            "question": intention.get("rewritten_question", ""),
        }
        if history is not None:
            template_kwargs["history"] = history

        template = HumanPromptTemplate.from_template(self._prompt_template)
        return template.format_messages(**template_kwargs)

    def _parse_response(self, text: str) -> Dict:
        """Parse llm response.

        The query is taken from, in order of preference:

        1. the first code fence tagged ``cypher`` (tag matched
           case-insensitively),
        2. the first code fence with no language tag,
        3. the whole response text.

        Args:
            text: Raw response text from the model.

        Returns:
            A dict with a single key ``"query"`` holding the extracted query
            with surrounding whitespace stripped.
        """
        # A tagged fence wins even when an untagged fence appears earlier, so
        # responses that already used a ```cypher fence parse as before.
        tagged_fence = re.compile(r"```cypher(.*?)```", re.S | re.I)
        # Untagged fence: the opening backticks are followed only by optional
        # horizontal whitespace and a newline. Fences tagged with some other
        # language are deliberately not matched here.
        bare_fence = re.compile(r"```[^\S\n]*\n(.*?)```", re.S)

        found = tagged_fence.search(text) or bare_fence.search(text)
        query = found.group(1) if found else text
        return {"query": query.strip()}

View File

@@ -382,9 +382,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
GraphElemType.ENTITY.value,
)
def delete_document(self, chunk_ids: str) -> None:
def delete_document(self, chunk_id: str) -> None:
"""Delete document in the graph."""
chunkids_list = [uuid.strip() for uuid in chunk_ids.split(",")]
chunkids_list = [uuid.strip() for uuid in chunk_id.split(",")]
del_chunk_gql = (
f"MATCH(m:{GraphElemType.DOCUMENT.value})-[r]->"
f"(n:{GraphElemType.CHUNK.value}) WHERE n.id IN {chunkids_list} DELETE n"

View File

@@ -210,6 +210,14 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
default=False,
description="Enable text2gql search or not.",
)
text2gql_model_enabled: bool = Field(
default=False,
description="Enable fine-tuned text2gql model for text2gql translation.",
)
text2gql_model_name: str = Field(
default=None,
description="LLM Model for text2gql translation.",
)
@register_resource(

View File

@@ -3,8 +3,9 @@
import logging
from typing import Tuple, Type
from dbgpt.storage import vector_store
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
from dbgpt_ext.storage import __vector_store__ as vector_store_list
from dbgpt_ext.storage import _select_rag_storage
logger = logging.getLogger(__name__)
@@ -35,9 +36,9 @@ class VectorStoreFactory:
@staticmethod
def __find_type(vector_store_type: str) -> Tuple[Type, Type]:
for t in vector_store.__vector_store__:
for t in vector_store_list:
if t.lower() == vector_store_type.lower():
store_cls, cfg_cls = getattr(vector_store, t)
store_cls, cfg_cls = _select_rag_storage(t)
if issubclass(store_cls, VectorStoreBase) and issubclass(
cfg_cls, VectorStoreConfig
):