mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 10:05:13 +00:00
GraphRAG: add new feature and bugfix (#2373)
Co-authored-by: 秉翟 <lyusonglin.lsl@antgroup.com>
This commit is contained in:
@@ -50,26 +50,6 @@ class SystemParameters:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StorageVectorConfig(BaseParameters):
|
|
||||||
type: str = field(
|
|
||||||
default="Chroma",
|
|
||||||
metadata={
|
|
||||||
"help": _("default vector type"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StorageGraphConfig(BaseParameters):
|
|
||||||
type: str = field(
|
|
||||||
default="TuGraph",
|
|
||||||
metadata={
|
|
||||||
"help": _("default graph type"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StorageConfig(BaseParameters):
|
class StorageConfig(BaseParameters):
|
||||||
vector: VectorStoreConfig = field(
|
vector: VectorStoreConfig = field(
|
||||||
|
@@ -5,7 +5,10 @@ import os
|
|||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
|
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
|
||||||
|
from dbgpt.rag.transformer.simple_intent_translator import SimpleIntentTranslator
|
||||||
from dbgpt.storage.graph_store.graph import Graph, MemoryGraph
|
from dbgpt.storage.graph_store.graph import Graph, MemoryGraph
|
||||||
|
from dbgpt_ext.rag.transformer.local_text2gql import LocalText2GQL
|
||||||
|
from dbgpt_ext.rag.transformer.text2gql import Text2GQL
|
||||||
|
|
||||||
from ...transformer.text_embedder import TextEmbedder
|
from ...transformer.text_embedder import TextEmbedder
|
||||||
from .base import GraphRetrieverBase
|
from .base import GraphRetrieverBase
|
||||||
@@ -81,10 +84,25 @@ class GraphRetriever(GraphRetrieverBase):
|
|||||||
if "TEXT_SEARCH_ENABLED" in os.environ
|
if "TEXT_SEARCH_ENABLED" in os.environ
|
||||||
else config.enable_text_search
|
else config.enable_text_search
|
||||||
)
|
)
|
||||||
|
text2gql_model_enabled = (
|
||||||
|
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
|
||||||
|
if "TEXT2GQL_MODEL_ENABLED" in os.environ
|
||||||
|
else config.text2gql_model_enabled
|
||||||
|
)
|
||||||
|
text2gql_model_name = os.getenv(
|
||||||
|
"TEXT2GQL_MODEL_NAME",
|
||||||
|
config.text2gql_model_name,
|
||||||
|
)
|
||||||
|
|
||||||
self._keyword_extractor = KeywordExtractor(llm_client, model_name)
|
self._keyword_extractor = KeywordExtractor(llm_client, model_name)
|
||||||
self._text_embedder = TextEmbedder(config.embedding_fn)
|
self._text_embedder = TextEmbedder(config.embedding_fn)
|
||||||
|
|
||||||
|
intent_interpreter = SimpleIntentTranslator(llm_client, model_name)
|
||||||
|
if text2gql_model_enabled:
|
||||||
|
text2gql = LocalText2GQL(text2gql_model_name)
|
||||||
|
else:
|
||||||
|
text2gql = Text2GQL(llm_client, model_name)
|
||||||
|
|
||||||
self._keyword_based_graph_retriever = KeywordBasedGraphRetriever(
|
self._keyword_based_graph_retriever = KeywordBasedGraphRetriever(
|
||||||
graph_store_adapter, triplet_topk
|
graph_store_adapter, triplet_topk
|
||||||
)
|
)
|
||||||
@@ -95,7 +113,10 @@ class GraphRetriever(GraphRetrieverBase):
|
|||||||
similarity_search_score_threshold,
|
similarity_search_score_threshold,
|
||||||
)
|
)
|
||||||
self._text_based_graph_retriever = TextBasedGraphRetriever(
|
self._text_based_graph_retriever = TextBasedGraphRetriever(
|
||||||
graph_store_adapter, triplet_topk, llm_client, model_name
|
graph_store_adapter,
|
||||||
|
triplet_topk,
|
||||||
|
intent_interpreter,
|
||||||
|
text2gql,
|
||||||
)
|
)
|
||||||
self._document_graph_retriever = DocumentGraphRetriever(
|
self._document_graph_retriever = DocumentGraphRetriever(
|
||||||
graph_store_adapter,
|
graph_store_adapter,
|
||||||
|
@@ -4,10 +4,8 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Dict, List, Tuple, Union
|
||||||
|
|
||||||
from dbgpt.rag.transformer.simple_intent_translator import SimpleIntentTranslator
|
|
||||||
from dbgpt.storage.graph_store.graph import Graph, MemoryGraph
|
from dbgpt.storage.graph_store.graph import Graph, MemoryGraph
|
||||||
from dbgpt_ext.rag.retriever.graph_retriever.base import GraphRetrieverBase
|
from dbgpt_ext.rag.retriever.graph_retriever.base import GraphRetrieverBase
|
||||||
from dbgpt_ext.rag.transformer.text2gql import Text2GQL
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -15,12 +13,18 @@ logger = logging.getLogger(__name__)
|
|||||||
class TextBasedGraphRetriever(GraphRetrieverBase):
|
class TextBasedGraphRetriever(GraphRetrieverBase):
|
||||||
"""Text Based Graph Retriever class."""
|
"""Text Based Graph Retriever class."""
|
||||||
|
|
||||||
def __init__(self, graph_store_adapter, triplet_topk, llm_client, model_name):
|
def __init__(
|
||||||
|
self,
|
||||||
|
graph_store_adapter,
|
||||||
|
triplet_topk,
|
||||||
|
intent_interpreter,
|
||||||
|
text2gql,
|
||||||
|
):
|
||||||
"""Initialize Text Based Graph Retriever."""
|
"""Initialize Text Based Graph Retriever."""
|
||||||
self._graph_store_adapter = graph_store_adapter
|
self._graph_store_adapter = graph_store_adapter
|
||||||
self._triplet_topk = triplet_topk
|
self._triplet_topk = triplet_topk
|
||||||
self._intent_interpreter = SimpleIntentTranslator(llm_client, model_name)
|
self._intent_interpreter = intent_interpreter
|
||||||
self._text2gql = Text2GQL(llm_client, model_name)
|
self._text2gql = text2gql
|
||||||
|
|
||||||
async def retrieve(self, text: str) -> Tuple[Graph, str]:
|
async def retrieve(self, text: str) -> Tuple[Graph, str]:
|
||||||
"""Retrieve from triplets graph with text2gql."""
|
"""Retrieve from triplets graph with text2gql."""
|
||||||
|
@@ -0,0 +1,70 @@
|
|||||||
|
"""LocalText2GQL class."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
|
from dbgpt.core import BaseMessage, HumanPromptTemplate
|
||||||
|
from dbgpt.model.proxy.llms.ollama import OllamaLLMClient
|
||||||
|
from dbgpt.rag.transformer.llm_translator import LLMTranslator
|
||||||
|
|
||||||
|
LOCAL_TEXT_TO_GQL_PT = """
|
||||||
|
A question written in graph query language style is provided below. Given the question, translate the question into a cypher query that can be executed on the given knowledge graph. Make sure the syntax of the translated cypher query is correct.
|
||||||
|
To help query generation, the schema of the knowledge graph is:
|
||||||
|
{schema}
|
||||||
|
---------------------
|
||||||
|
Question: {question}
|
||||||
|
Query:
|
||||||
|
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalText2GQL(LLMTranslator):
    """Translate a rewritten question into a graph query via a local model.

    The translator is constructed with an ``OllamaLLMClient`` and the
    ``LOCAL_TEXT_TO_GQL_PT`` prompt template, and extracts the generated
    query from the model's reply.
    """

    # Preferred form: a fenced block explicitly tagged as cypher. The tag is
    # matched case-insensitively because models do not capitalise it
    # consistently (```cypher / ```Cypher / ```CYPHER).
    _CYPHER_BLOCK_PATTERN = re.compile(r"```cypher(.*?)```", re.S | re.I)
    # Fallback: any fenced block. An optional language-tag line directly after
    # the opening fence is skipped so that it does not leak into the query.
    _ANY_BLOCK_PATTERN = re.compile(r"```(?:[\w+-]*\n)?(.*?)```", re.S)

    def __init__(self, model_name: str):
        """Initialize the LocalText2GQL.

        Args:
            model_name: Model name forwarded to the ``LLMTranslator`` base
                class together with a fresh ``OllamaLLMClient``.
        """
        super().__init__(OllamaLLMClient(), model_name, LOCAL_TEXT_TO_GQL_PT)

    def _format_messages(
        self, text: str, history: Union[str, None] = None
    ) -> List[BaseMessage]:
        """Build the prompt messages for a single-prompt translation.

        Args:
            text: JSON-encoded intention. The keys ``"rewritten_question"``
                and ``"schema"`` are read; each defaults to ``""`` when
                missing.
            history: Optional history; passed to the template as ``history``
                only when it is not ``None``.

        Returns:
            The messages produced by the human prompt template.

        Raises:
            json.JSONDecodeError: If ``text`` is not valid JSON.
        """
        # translate intention to gql with single prompt only.
        intention: Dict[str, Union[str, List[str]]] = json.loads(text)

        # NOTE(review): ``self._prompt_template`` is presumably set by the
        # base class from the template passed in ``__init__`` - confirm.
        template = HumanPromptTemplate.from_template(self._prompt_template)

        template_args = {
            "schema": intention.get("schema", ""),
            "question": intention.get("rewritten_question", ""),
        }
        if history is not None:
            template_args["history"] = history

        return template.format_messages(**template_args)

    def _parse_response(self, text: str) -> Dict:
        """Parse llm response.

        The query is taken from, in order of preference: the first fenced
        block tagged ``cypher`` (any letter case), the first fenced block of
        any kind, or the whole response when no fence is present.

        Args:
            text: Raw text returned by the model.

        Returns:
            A dict with a single key ``"query"`` holding the stripped query.
        """
        match = self._CYPHER_BLOCK_PATTERN.search(
            text
        ) or self._ANY_BLOCK_PATTERN.search(text)
        # Without any fence the model is assumed to have answered with the
        # bare query, so the full response is used.
        query = match.group(1) if match else text

        return {"query": query.strip()}
|
@@ -382,9 +382,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
|||||||
GraphElemType.ENTITY.value,
|
GraphElemType.ENTITY.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_document(self, chunk_ids: str) -> None:
|
def delete_document(self, chunk_id: str) -> None:
|
||||||
"""Delete document in the graph."""
|
"""Delete document in the graph."""
|
||||||
chunkids_list = [uuid.strip() for uuid in chunk_ids.split(",")]
|
chunkids_list = [uuid.strip() for uuid in chunk_id.split(",")]
|
||||||
del_chunk_gql = (
|
del_chunk_gql = (
|
||||||
f"MATCH(m:{GraphElemType.DOCUMENT.value})-[r]->"
|
f"MATCH(m:{GraphElemType.DOCUMENT.value})-[r]->"
|
||||||
f"(n:{GraphElemType.CHUNK.value}) WHERE n.id IN {chunkids_list} DELETE n"
|
f"(n:{GraphElemType.CHUNK.value}) WHERE n.id IN {chunkids_list} DELETE n"
|
||||||
|
@@ -210,6 +210,14 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
|
|||||||
default=False,
|
default=False,
|
||||||
description="Enable text2gql search or not.",
|
description="Enable text2gql search or not.",
|
||||||
)
|
)
|
||||||
|
text2gql_model_enabled: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Enable fine-tuned text2gql model for text2gql translation.",
|
||||||
|
)
|
||||||
|
text2gql_model_name: str = Field(
|
||||||
|
default=None,
|
||||||
|
description="LLM Model for text2gql translation.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_resource(
|
@register_resource(
|
||||||
|
@@ -3,8 +3,9 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Tuple, Type
|
from typing import Tuple, Type
|
||||||
|
|
||||||
from dbgpt.storage import vector_store
|
|
||||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||||
|
from dbgpt_ext.storage import __vector_store__ as vector_store_list
|
||||||
|
from dbgpt_ext.storage import _select_rag_storage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -35,9 +36,9 @@ class VectorStoreFactory:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __find_type(vector_store_type: str) -> Tuple[Type, Type]:
|
def __find_type(vector_store_type: str) -> Tuple[Type, Type]:
|
||||||
for t in vector_store.__vector_store__:
|
for t in vector_store_list:
|
||||||
if t.lower() == vector_store_type.lower():
|
if t.lower() == vector_store_type.lower():
|
||||||
store_cls, cfg_cls = getattr(vector_store, t)
|
store_cls, cfg_cls = _select_rag_storage(t)
|
||||||
if issubclass(store_cls, VectorStoreBase) and issubclass(
|
if issubclass(store_cls, VectorStoreBase) and issubclass(
|
||||||
cfg_cls, VectorStoreConfig
|
cfg_cls, VectorStoreConfig
|
||||||
):
|
):
|
||||||
|
Reference in New Issue
Block a user